diff --git a/pkg/document/document.go b/pkg/document/document.go index a58011b..f7b7d60 100644 --- a/pkg/document/document.go +++ b/pkg/document/document.go @@ -15,6 +15,11 @@ import ( "github.com/zerx-lab/wordZero/pkg/style" ) +const ( + wordprocessingMLNamespace = "http://schemas.openxmlformats.org/wordprocessingml/2006/main" + wordprocessingMLStrictNamespace = "http://purl.oclc.org/ooxml/wordprocessingml/main" +) + // Document 表示一个Word文档 type Document struct { // 文档的主要内容 @@ -2016,21 +2021,42 @@ func (d *Document) parseDocument() error { switch t := token.(type) { case xml.StartElement: - if t.Name.Local == "document" && t.Name.Space == "http://schemas.openxmlformats.org/wordprocessingml/2006/main" { + if isWordprocessingDocumentElement(t) { // 开始解析文档 if err := d.parseDocumentElement(decoder); err != nil { return err } goto done } + if t.Name.Local == "document" { + return WrapError("parse_document", fmt.Errorf("%w: unsupported document namespace %q", ErrInvalidDocument, t.Name.Space)) + } } } + return WrapError("parse_document", fmt.Errorf("%w: missing document root", ErrInvalidDocument)) + done: + if d.Body == nil { + return WrapError("parse_document", fmt.Errorf("%w: missing document body", ErrInvalidDocument)) + } Infof("解析完成,共 %d 个元素", len(d.Body.Elements)) return nil } +func isWordprocessingDocumentElement(start xml.StartElement) bool { + if start.Name.Local != "document" { + return false + } + + switch start.Name.Space { + case wordprocessingMLNamespace, wordprocessingMLStrictNamespace: + return true + default: + return false + } +} + // parseDocumentElement 解析文档元素 func (d *Document) parseDocumentElement(decoder *xml.Decoder) error { // 初始化Body diff --git a/pkg/document/document_test.go b/pkg/document/document_test.go index c8e54f3..f044f57 100644 --- a/pkg/document/document_test.go +++ b/pkg/document/document_test.go @@ -1,6 +1,10 @@ package document import ( + "archive/zip" + "bytes" + "errors" + "io" "os" "testing" @@ -958,6 +962,77 @@ func TestDocumentOpenFromMemory(t *testing.T) { } } +func TestDocumentOpenFromMemoryWithStrictNamespace(t *testing.T) { + files := buildTestDocxReader(t, ` + + + + + Strict 文档 + + + +`) + + doc, err := OpenFromMemory(files) + if err != nil { + t.Fatalf("Failed to open strict namespace document: %v", err) + } + + paragraphs := doc.Body.GetParagraphs() + if len(paragraphs) != 1 { + t.Fatalf("Expected 1 paragraph, got %d", len(paragraphs)) + } + + if got := paragraphs[0].Runs[0].Text.Content; got != "Strict 文档" { + t.Fatalf("Expected strict namespace text to be parsed, got %q", got) + } +} + +func TestDocumentOpenFromMemoryWithInvalidRootReturnsError(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("OpenFromMemory should return an error instead of panicking: %v", r) + } + }() + + files := buildTestDocxReader(t, ` + + +`) + + _, err := OpenFromMemory(files) + if err == nil { + t.Fatal("Expected invalid document root to return an error") + } + + if !errors.Is(err, ErrInvalidDocument) { + t.Fatalf("Expected ErrInvalidDocument, got %v", err) + } +} + +func buildTestDocxReader(t *testing.T, documentXML string) io.ReadCloser { + t.Helper() + + var buf bytes.Buffer + zipWriter := zip.NewWriter(&buf) + + fileWriter, err := zipWriter.Create("word/document.xml") + if err != nil { + t.Fatalf("Failed to create document.xml entry: %v", err) + } + + if _, err := fileWriter.Write([]byte(documentXML)); err != nil { + t.Fatalf("Failed to write document.xml entry: %v", err) + } + + if err := zipWriter.Close(); err != nil { + t.Fatalf("Failed to close test DOCX writer: %v", err) + } + + return io.NopCloser(bytes.NewReader(buf.Bytes())) +} + // TestAddPageBreak 测试添加分页符功能 func TestAddPageBreak(t *testing.T) { doc := New()