diff --git a/enumify.go b/enumify.go index 5b4eb93..1b89971 100644 --- a/enumify.go +++ b/enumify.go @@ -154,6 +154,7 @@ func Generate(opts Options) (err error) { typeID := etype.Id() // Enum namesVar := etype.NamesVarId() // enumNames namesVarSlice := etype.NamesVarSliceId() // enumNames or enumNames[0] + parserID := g.Id("Parse" + etype.Name) // ParseEnum f.Comment("//============================================================================") f.Comment("// " + etype.Name + " Enum Type: Generated Functions and Methods") @@ -197,6 +198,57 @@ func Generate(opts Options) (err error) { f.Line() } + // Text Marshal and Unmarshal code + if !opts.NoText { + f.Commentf("Ensure %s implements text.Marshaler.", etype.Name) + method := methodSig.Clone().Id("MarshalText").Call().Params(g.Id("[]byte"), g.Error()) + method.Block( + g.Return(g.Id("[]byte").Call(s.Clone().Dot("String").Call()), g.Nil()), + ) + f.Add(method) + f.Line() + + f.Commentf("Ensure %s implements text.Unmarshaler.", etype.Name) + method = methodPtrSig.Clone().Id("UnmarshalText").Call(g.Id("data").Id("[]byte")).Params(evar.Clone().Error()) + method.Block( + ptrS.Clone().Op(",").Add(evar).Op("=").Add(parserID.Clone()).Call(g.String().Call(g.Id("data"))), + g.Return(evar), + ) + f.Add(method) + f.Line() + } + + // Binary Marshal and Unmarshal code + if !opts.NoBinary { + f.Commentf("Ensure %s implements binary.Marshaler.", etype.Name) + method := methodSig.Clone().Id("MarshalBinary").Call().Params(g.Id("[]byte"), g.Error()) + method.Block( + g.Return(g.Id("[]byte").Values(g.Byte().Call(s)), g.Nil()), + ) + f.Add(method) + f.Line() + + f.Commentf("Ensure %s implements binary.Unmarshaler.", etype.Name) + method = methodPtrSig.Clone().Id("UnmarshalBinary").Call(g.Id("data").Id("[]byte")).Params(evar.Clone().Error()) + method.Block( + g.Switch(g.Len(g.Id("data"))).Block( + g.Case(g.Lit(0)).Block( + ptrS.Clone().Op("=").Add(typeID.Clone().Call(g.Lit(0))), + g.Return(g.Nil()), + ), + g.Case(g.Lit(1)).Block( + ptrS.Clone().Op("=").Add(typeID.Clone().Call(g.Id("data").Index(g.Lit(0)))), + g.Return(g.Nil()), + ), + g.Default().Block( + g.Return(g.Qual("fmt", "Errorf").Call(g.Lit("cannot unmarshal %d bytes into "+etype.Name), g.Len(g.Id("data")))), + ), + ), + ) + f.Add(method) + f.Line() + } + // JSON Marshal and Unmarshal code if !opts.NoJSON { f.Commentf("Ensure %s implements json.Marshaler.", etype.Name) @@ -220,7 +272,7 @@ func Generate(opts Options) (err error) { ), g.Line(), g.If( - ptrS.Clone().Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(sv).Op(";").Id("err").Op("!=").Nil(). + ptrS.Clone().Op(",").Id("err").Op("=").Add(parserID.Clone()).Call(sv).Op(";").Id("err").Op("!=").Nil(). Block( g.Return(evar), ), @@ -251,7 +303,7 @@ func Generate(opts Options) (err error) { g.Return(evar), ), g.Line(), - g.If(g.Add(ptrS.Clone()).Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(sv).Op(";").Id("err").Op("!=").Nil(). + g.If(g.Add(ptrS.Clone()).Op(",").Id("err").Op("=").Add(parserID.Clone()).Call(sv).Op(";").Id("err").Op("!=").Nil(). Block( g.Return(evar), ), @@ -271,11 +323,11 @@ func Generate(opts Options) (err error) { g.Switch(val.Clone().Op(":=").Id("src").Assert(g.Type())).Block( g.Case(g.Nil()).Block(g.Return(g.Nil())), g.Case(g.String()).Block( - g.List(ptrS.Clone(), evar.Clone().Op("=").Id("Parse"+etype.Name).Call(val)), + g.List(ptrS.Clone(), evar.Clone().Op("=").Add(parserID.Clone()).Call(val)), g.Return(evar), ), g.Case(g.Id("[]byte")).Block( - g.List(ptrS.Clone(), evar.Clone().Op("=").Id("Parse"+etype.Name).Call(g.String().Call(val))), + g.List(ptrS.Clone(), evar.Clone().Op("=").Add(parserID.Clone()).Call(g.String().Call(val))), g.Return(evar), ), g.Default().Block( diff --git a/parse.go b/parse.go index 791b38b..d6b26d1 100644 --- a/parse.go +++ b/parse.go @@ -1,6 +1,7 @@ package enumify import ( + "errors" "fmt" "strings" ) @@ -47,6 +48,11 @@ func ParseFactory[T ~uint8, Names []string | [][]string](names Names) Parser[T] unknown := T(0) return func(val any) (T, error) { + // Convert byte slices to strings. + if v, ok := val.([]byte); ok && len(v) > 1 { + val = string(v) + } + switch v := val.(type) { case string: v = normalize(v) @@ -65,6 +71,15 @@ func ParseFactory[T ~uint8, Names []string | [][]string](names Names) Parser[T] // If no match is found, return an error. return unknown, fmt.Errorf("invalid %T value: %q", unknown, v) + case []byte: + switch len(v) { + case 0: + return unknown, nil + case 1: + return T(v[0]), nil + default: + panic(errors.New("byte slices should be parsed as strings: this code should be unreachable")) + } case T: if v >= T(len(normalizedNames)) { return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) diff --git a/parse_test.go b/parse_test.go index 00da3ec..8ca4971 100644 --- a/parse_test.go +++ b/parse_test.go @@ -32,6 +32,22 @@ func TestParseFactory(t *testing.T) { {value: "REVIEW", expected: StatusReview}, {value: "PUBLISHED", expected: StatusPublished}, {value: "ARCHIVED", expected: StatusArchived}, + {value: byte(0), expected: StatusUnknown}, + {value: byte(1), expected: StatusDraft}, + {value: byte(2), expected: StatusReview}, + {value: byte(3), expected: StatusPublished}, + {value: byte(4), expected: StatusArchived}, + {value: []byte{0}, expected: StatusUnknown}, + {value: []byte{1}, expected: StatusDraft}, + {value: []byte{2}, expected: StatusReview}, + {value: []byte{3}, expected: StatusPublished}, + {value: []byte{4}, expected: StatusArchived}, + {value: []byte(""), expected: StatusUnknown}, + {value: []byte("unknown"), expected: StatusUnknown}, + {value: []byte("draft"), expected: StatusDraft}, + {value: []byte("review"), expected: StatusReview}, + {value: []byte("published"), expected: StatusPublished}, + {value: []byte("archived"), expected: StatusArchived}, {value: uint(0), expected: StatusUnknown}, {value: uint(1), expected: StatusDraft}, {value: uint(2), expected: StatusReview},