From 862b2c19893fe089303deb39c5fb05b28ea351bf Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Tue, 7 Apr 2026 14:25:06 -0500 Subject: [PATCH 1/2] [FEAT] Code Generation, Finally! --- cmd/enumify/main.go | 2 + enumify.go | 232 +++++++++++++++++++++++++++++++++---------- example/calendar.go | 4 +- export_test.go | 9 ++ strings.go | 31 ++++++ strings_test.go | 49 +++++++++ testdata/namevars.go | 56 +++++++++++ testing.go | 2 + typedef.go | 69 ++++++++++++- typedef_test.go | 67 +++++++++++++ 10 files changed, 464 insertions(+), 57 deletions(-) create mode 100644 export_test.go create mode 100644 strings.go create mode 100644 strings_test.go create mode 100644 testdata/namevars.go create mode 100644 typedef_test.go diff --git a/cmd/enumify/main.go b/cmd/enumify/main.go index 98c364d..25f5139 100644 --- a/cmd/enumify/main.go +++ b/cmd/enumify/main.go @@ -18,6 +18,8 @@ func main() { // Bind flag variables directly to the opts struct flag.StringVar(&opts.NameVar, "names", "", "variable name that contains the string reprs of the enum values") + flag.BoolVar(&opts.CaseSensitive, "case-sensitive", false, "make the enum case sensitive") + flag.BoolVar(&opts.SpaceSensitive, "space-sensitive", false, "make the enum space sensitive") flag.BoolVar(&opts.NoTests, "no-tests", false, "skip testing code generation") flag.BoolVar(&opts.NoParser, "no-parser", false, "skip parser code generation") flag.BoolVar(&opts.NoStringer, "no-stringer", false, "skip Stringer interface code generation") diff --git a/enumify.go b/enumify.go index 3a7256e..c276dc9 100644 --- a/enumify.go +++ b/enumify.go @@ -24,58 +24,27 @@ type Enum interface { // Options defines how the code generation should be performed. These values are set by // the CLI flags passed via the go generate directive. type Options struct { - File string // The file that kicked off the generation (from $GOFILE) - Pkg string // The package that the enum is in (from $GOPACKAGE) - NameVar string // The variable name that contains the names of the enum values - NoTests bool // Whether to skip the testing code generation (defaults to false) - NoParser bool // Whether to skip the parser code generation (defaults to false) - NoStringer bool // Whether to skip the Stringer interface code generation (defaults to false) - NoText bool // Whether to skip the text interfaces code generation (defaults to false) - NoBinary bool // Whether to skip the binary interfaces code generation (defaults to false) - NoJSON bool // Whether to skip the JSON interfaces code generation (defaults to false) - NoYAML bool // Whether to skip the YAML interfaces code generation (defaults to false) - NoSQL bool // Whether to skip the SQL interfaces code generation (defaults to false) + File string // The file that kicked off the generation (from $GOFILE) + Pkg string // The package that the enum is in (from $GOPACKAGE) + NameVar string // The variable name that contains the names of the enum values + CaseSensitive bool // Whether to make the enum case sensitive (defaults to false) + SpaceSensitive bool // Whether to make the enum space sensitive (defaults to false) + NoTests bool // Whether to skip the testing code generation (defaults to false) + NoParser bool // Whether to skip the parser code generation (defaults to false) + NoStringer bool // Whether to skip the Stringer interface code generation (defaults to false) + NoText bool // Whether to skip the text interfaces code generation (defaults to false) + NoBinary bool // Whether to skip the binary interfaces code generation (defaults to false) + NoJSON bool // Whether to skip the JSON interfaces code generation (defaults to false) + NoYAML bool // Whether to skip the YAML interfaces code generation (defaults to false) + NoSQL bool // Whether to skip the SQL interfaces code generation (defaults to false) } -func GenerateComment() string { - return fmt.Sprintf("Code generated by enumify v%s. DO NOT EDIT.", Version()) -} - -func Generate(opts Options) (err error) { - f := g.NewFile(opts.Pkg) - f.PackageComment(GenerateComment()) - - if _, err = opts.discover(); err != nil { - return err - } - - if err = f.Save(opts.fileName()); err != nil { - return err - } - return nil -} - -func GenerateTests(opts Options) (err error) { - f := g.NewFile(opts.Pkg + "_test") - f.PackageComment(GenerateComment()) - - if err = f.Save(opts.testFileName()); err != nil { - return err - } - return nil -} - -func (o Options) fileName() string { - ext := filepath.Ext(o.File) - return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen" + ext -} - -func (o Options) testFileName() string { - ext := filepath.Ext(o.File) - return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen_test" + ext -} - -func (o *Options) discover() (etypes EnumTypes, err error) { +// Discover is the entry point for all code generation. It will parse the file specified +// by the go generate directive into an AST, then loop through all of the data in the +// file to discover the Enum types (anything that extends uint8), the constant values +// for the enum types, and the names variable based on the name pattern, the passed in +// name, and the allowed name var types. +func Discover(o Options) (etypes EnumTypes, err error) { // Build tool package discovery configuration. cfg := &packages.Config{ Mode: packages.NeedTypes | packages.NeedTypesInfo, @@ -150,3 +119,166 @@ func (o *Options) discover() (etypes EnumTypes, err error) { } return etypes, nil } + +//============================================================================ +// Code Generation Entry Points +//============================================================================ + +// The generation comment is considered best practice when writing golang code generators. +// In addition to the standard warning, it also includes the version of the generator that +// was used to generate the code for debugging purposes. +func GenerateComment() string { + return fmt.Sprintf("Code generated by enumify v%s. DO NOT EDIT.", Version()) +} + +// Generate the code file for the enum types. +func Generate(opts Options) (err error) { + f := g.NewFile(opts.Pkg) + f.PackageComment(GenerateComment()) + + // Manage import names + f.ImportName("go.rtnl.ai/enumify", "enumify") + f.ImportName("encoding/json", "json") + + var etypes EnumTypes + if etypes, err = Discover(opts); err != nil { + return err + } + + // Write the code for the enum types. + for _, etype := range etypes { + typeID := etype.Id() + namesVar := etype.NamesVarId() + + f.Comment("//============================================================================") + f.Comment("// " + etype.Name + " Enum Type: Generated Functions and Methods") + f.Comment("//============================================================================") + f.Line() + + // Write the parser code for the enum type + if !opts.NoParser { + // TODO: this uses the factory function to create the parser function. This + // is not ideal because it creates a dependency on the go.rtnl.ai/enumify + // package. See https://github.com/rotationalio/enumify/issues/7 for + // ideas on how to improve this. + parserVar := g.Id(LowerFirst(etype.Name) + "Parser") + parserType := g.Func().Params(g.Any()).Params(typeID, g.Error()) + f.Var().Add(parserVar, parserType).Op("=").Qual("go.rtnl.ai/enumify", "ParseFactory").Types(typeID).Call(namesVar) + f.Line() + + f.Commentf("Parse%s parses the given value into a %s.", etype.Name, etype.Name) + parserFunc := f.Func().Id("Parse" + UpperFirst(etype.Name)) + parserFunc.Params(g.Id("s").Any()).Params(typeID, g.Error()).Block( + g.Return(parserVar.Clone().Call(g.Id("s"))), + ) + f.Line() + } + + s := g.Id("s") + methodSig := g.Func().Params(s.Clone().Add(typeID)) + methodPtrSig := g.Func().Params(s.Clone().Op("*").Add(typeID)) + + // Write the stringer code for the enum type + if !opts.NoStringer { + f.Commentf("Ensure %s implements fmt.Stringer.", etype.Name) + + method := methodSig.Clone().Id("String").Call().String() + method.Block( + g.If(s.Clone().Op(">=").Add(typeID).Call(g.Len(namesVar))).Block( + g.Return(etype.IndexNames(etype.ZeroConstId())), + ), + g.Return(etype.IndexNames(s)), + ) + f.Add(method) + f.Line() + } + + // JSON Marshal and Unmarshal code + if !opts.NoJSON { + f.Commentf("Ensure %s implements json.Marshaler.", etype.Name) + method := methodSig.Clone().Id("MarshalJSON").Call().Params(g.Id("[]byte"), g.Error()) + method.Block( + g.Return(g.Qual("encoding/json", "Marshal").Call(s.Clone().Dot("String").Call())), + ) + f.Add(method) + f.Line() + + f.Commentf("Ensure %s implements json.Unmarshaler.", etype.Name) + method = methodPtrSig.Clone().Id("UnmarshalJSON").Call(g.Id("data").Id("[]byte")).Params(g.Id("err").Error()) + method.Block( + g.Var().Id("v").Any(), + g.If( + g.Id("err").Op("=").Qual("encoding/json", "Unmarshal").Call(g.Id("data"), g.Op("&").Id("v")).Op(";").Id("err").Op("!=").Nil()). + Block( + g.Return(g.Id("err")), + ), + g.Line(), + g.If( + g.Op("*").Add(s).Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(g.Id("v")).Op(";").Id("err").Op("!=").Nil(). + Block( + g.Return(g.Id("err")), + ), + ), + g.Return().Nil(), + ) + f.Add(method) + f.Line() + } + + } + + if err = f.Save(opts.fileName()); err != nil { + return err + } + return nil +} + +// Generate the test file for the enum types. +func GenerateTests(opts Options) (err error) { + f := g.NewFile(opts.Pkg) + f.PackageComment(GenerateComment()) + + // Manage import names + f.ImportName("testing", "testing") + f.ImportName("go.rtnl.ai/enumify", "enumify") + + // Discover the enum types in the package. + var etypes EnumTypes + if etypes, err = Discover(opts); err != nil { + return err + } + + // For each enum type, write a test function that creates and executes an enumify + // test suite for the enum type. + for _, etype := range etypes { + f.Func().Id("Test"+UpperFirst(etype.Name)).Add(TestingT).Block( + g.Id("suite").Op(":=").Qual("go.rtnl.ai/enumify", "TestSuite").Types(etype.Id(), etype.NamesVarTypeId()).Block( + g.Id("Values").Op(":").Add(etype.ConstLiteral()).Op(","), + g.Id("Names").Op(":").Add(etype.NamesVarId()).Op(","), + g.Id("ICase").Op(":").Lit(!opts.CaseSensitive).Op(","), + g.Id("ISpace").Op(":").Lit(!opts.SpaceSensitive).Op(","), + ), + g.Line(), + g.Id("suite").Dot("Run").Call(g.Id("t")), + ).Line() + } + + if err = f.Save(opts.testFileName()); err != nil { + return err + } + return nil +} + +//============================================================================ +// Options Methods +//============================================================================ + +func (o Options) fileName() string { + ext := filepath.Ext(o.File) + return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen" + ext +} + +func (o Options) testFileName() string { + ext := filepath.Ext(o.File) + return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen_test" + ext +} diff --git a/example/calendar.go b/example/calendar.go index 1772629..4ed5ece 100644 --- a/example/calendar.go +++ b/example/calendar.go @@ -26,7 +26,7 @@ const ( // and because it matches the dayNames pattern to connect it with the Day enum. // //lint:ignore U1000 this is used by the enumify generator -var dayNames = [...]string{ +var dayNames = []string{ "unknown", "Monday", "Tuesday", @@ -61,7 +61,7 @@ const ( // having to specify it using the -names flag. // //lint:ignore U1000 this is used by the enumify generator -var monthNames = [2][13]string{ +var monthNames = [][]string{ {"", "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"}, {"", "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}, } diff --git a/export_test.go b/export_test.go new file mode 100644 index 0000000..57b4c4b --- /dev/null +++ b/export_test.go @@ -0,0 +1,9 @@ +package enumify + +var ( + IsStringArray = isStringArray + IsStringSlice = isStringSlice + IsStringTable = isStringTable + IsStrings2DArray = isStrings2DArray + IsNamesType = isNamesType +) diff --git a/strings.go b/strings.go new file mode 100644 index 0000000..3ad2777 --- /dev/null +++ b/strings.go @@ -0,0 +1,31 @@ +package enumify + +import ( + "unicode" + "unicode/utf8" +) + +func LowerFirst(s string) string { + if s == "" { + return "" + } + + r, size := utf8.DecodeRuneInString(s) + if r == utf8.RuneError { + return s + } + + return string(unicode.ToLower(r)) + s[size:] +} + +func UpperFirst(s string) string { + if s == "" { + return "" + } + + r, size := utf8.DecodeRuneInString(s) + if r == utf8.RuneError { + return s + } + return string(unicode.ToUpper(r)) + s[size:] +} diff --git a/strings_test.go b/strings_test.go new file mode 100644 index 0000000..bf2ff8e --- /dev/null +++ b/strings_test.go @@ -0,0 +1,49 @@ +package enumify_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.rtnl.ai/enumify" +) + +func TestLowerFirst(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {input: "", expected: ""}, + {input: "foo", expected: "foo"}, + {input: "Foo", expected: "foo"}, + {input: "FOO", expected: "fOO"}, + {input: "FooBar", expected: "fooBar"}, + {input: "ΩTime", expected: "ωTime"}, + {input: "ωTime", expected: "ωTime"}, + } + + for _, test := range tests { + actual := enumify.LowerFirst(test.input) + require.Equal(t, test.expected, actual) + } +} + +func TestUpperFirst(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {input: "", expected: ""}, + {input: "foo", expected: "Foo"}, + {input: "Foo", expected: "Foo"}, + {input: "FOO", expected: "FOO"}, + {input: "FooBar", expected: "FooBar"}, + {input: "fooBar", expected: "FooBar"}, + {input: "ΩTime", expected: "ΩTime"}, + {input: "ωTime", expected: "ΩTime"}, + } + + for _, test := range tests { + actual := enumify.UpperFirst(test.input) + require.Equal(t, test.expected, actual) + } +} diff --git a/testdata/namevars.go b/testdata/namevars.go new file mode 100644 index 0000000..85633c1 --- /dev/null +++ b/testdata/namevars.go @@ -0,0 +1,56 @@ +package testdata + +var namesArray = [...]string{ + "foo", + "bar", + "baz", +} + +var namesSlice = []string{ + "foo", + "bar", + "baz", +} + +var namesTable = [][]string{ + {"foo", "bar", "baz"}, + {"foo", "bar", "baz"}, +} + +var names2DArray = [2][3]string{ + {"foo", "bar", "baz"}, + {"foo", "bar", "baz"}, +} + +var notNamesArray = [...]int{ + 1, + 2, + 3, +} + +var notNamesSlice = []int{ + 1, + 2, + 3, +} + +var notNamesTable = [][]int{ + {1, 2, 3}, + {1, 2, 3}, +} + +var notNames2DArray = [2][3]int{ + {1, 2, 3}, + {1, 2, 3}, +} + +const ( + Foo int32 = iota + Bar + Baz +) + +var ( + debug bool = false + progName string = "testing" +) diff --git a/testing.go b/testing.go index 997dc03..a3bd0bd 100644 --- a/testing.go +++ b/testing.go @@ -8,11 +8,13 @@ import ( "testing" "unicode" + g "github.com/dave/jennifer/jen" "github.com/stretchr/testify/require" ) var ( DefaultInvalid = []any{"foo", "123", "INVALID", 257, -1, 314.314, struct{}{}, true, false} + TestingT = g.Params(g.Id("t").Op("*").Qual("testing", "T")) ) const ( diff --git a/typedef.go b/typedef.go index 29c1812..fac1a4e 100644 --- a/typedef.go +++ b/typedef.go @@ -4,8 +4,8 @@ import ( "fmt" "go/constant" "go/types" - "unicode" + g "github.com/dave/jennifer/jen" "golang.org/x/tools/go/packages" ) @@ -20,8 +20,63 @@ type EnumType struct { scope *types.Scope } +//============================================================================ +// EnumType Code Generation Utilities +//============================================================================ + +func (e *EnumType) Id() *g.Statement { + return g.Id(e.Name) +} + +func (e *EnumType) NamesVarId() *g.Statement { + return g.Id(e.NamesVar.Name()) +} + +func (e *EnumType) NamesVarTypeId() *g.Statement { + switch { + case isStringTable(e.NamesVar.Type()): + return g.Id("[][]string") + case isStringSlice(e.NamesVar.Type()): + return g.Id("[]string") + default: + panic(fmt.Errorf("unsupported names variable type: %T", e.NamesVar.Type())) + } +} + +func (e *EnumType) IndexNames(i g.Code) *g.Statement { + switch { + case isStringTable(e.NamesVar.Type()): + return e.NamesVarId().Index(g.Lit(0)).Index(i) + case isStringSlice(e.NamesVar.Type()): + return e.NamesVarId().Index(i) + default: + panic(fmt.Errorf("unsupported names variable type: %T", e.NamesVar.Type())) + + } +} + +func (e *EnumType) ConstLiteral() *g.Statement { + constLits := make([]g.Code, 0, len(e.Consts)) + for _, constObj := range e.Consts { + constLit := g.Id(constObj.Name()) + constLits = append(constLits, constLit) + } + return g.Id("[]" + e.Name).Add(g.Values(constLits...)) +} + +func (e *EnumType) ZeroConstId() *g.Statement { + if zeroConst := e.zeroConst(); zeroConst != nil { + return g.Id(zeroConst.Name()) + } + return g.Lit(0) +} + +//============================================================================ +// EnumType AST Validation and Parsing +//============================================================================ + func (e *EnumType) validate() error { - // Not sure how this would happen ... but we're going to LBYL becuase types are hard. + // Not sure how this would happen ... but we're going to LBYL because types are hard. if e.Name == "" || e.Type == nil { return fmt.Errorf("enum type %q has no name or type", e.Name) } @@ -133,12 +188,16 @@ func (e *EnumType) setNamesVar(name string) error { } func (e *EnumType) namesMatch(name string) bool { - chr := []rune(e.Name)[0] - ucc := string(unicode.ToUpper(chr)) + e.Name[1:] + "Names" - lcc := string(unicode.ToLower(chr)) + e.Name[1:] + "Names" + target := e.Name + "Names" + ucc := UpperFirst(target) + lcc := LowerFirst(target) return name == ucc || name == lcc } +//============================================================================ +// Helper Functions +//============================================================================ + func isNamesType(typ types.Type) bool { return isStringArray(typ) || isStringSlice(typ) || isStringTable(typ) || isStrings2DArray(typ) } diff --git a/typedef_test.go b/typedef_test.go new file mode 100644 index 0000000..6f88ccc --- /dev/null +++ b/typedef_test.go @@ -0,0 +1,67 @@ +package enumify_test + +import ( + "fmt" + "go/types" + "testing" + + "github.com/stretchr/testify/require" + . "go.rtnl.ai/enumify" + "golang.org/x/tools/go/packages" +) + +type NameVarTypeCheck func(typ types.Type) bool + +func TestNameVarTypeChecks(t *testing.T) { + testNameVarTypeCheck := func(nvtc NameVarTypeCheck, valid string) func(t *testing.T) { + return func(t *testing.T) { + testCases, err := ParseTypes("testdata/namevars.go") + require.NoError(t, err, "could not parse types from testdata/namevars.go") + require.GreaterOrEqual(t, len(testCases), 1, "expected at least 1 test case") + + validated := false + for _, obj := range testCases { + if obj.Name() == valid { + require.True(t, IsNamesType(obj.Type()), "expected %q to be a names type", obj.Name()) + require.True(t, nvtc(obj.Type()), "expected %q to pass type check", obj.Name()) + validated = true + } else { + require.False(t, nvtc(obj.Type()), "expected %q to not be a valid type", obj.Name()) + } + } + require.True(t, validated, "no valid example for type check found") + } + } + + t.Run("StringArray", testNameVarTypeCheck(IsStringArray, "namesArray")) + t.Run("StringSlice", testNameVarTypeCheck(IsStringSlice, "namesSlice")) + t.Run("StringTable", testNameVarTypeCheck(IsStringTable, "namesTable")) + t.Run("Strings2DArray", testNameVarTypeCheck(IsStrings2DArray, "names2DArray")) +} + +func ParseTypes(file string) (out []types.Object, err error) { + cfg := &packages.Config{ + Mode: packages.NeedTypes | packages.NeedTypesInfo, + } + + var pkgs []*packages.Package + if pkgs, err = packages.Load(cfg, file); err != nil { + return nil, fmt.Errorf("could not load package from %q: %w", file, err) + } + + if len(pkgs) != 1 { + return nil, fmt.Errorf("expected only 1 package returned from packages.Load, got %d", len(pkgs)) + } + + if len(pkgs[0].Errors) > 0 { + return nil, fmt.Errorf("package errors: %v", pkgs[0].Errors) + } + + names := pkgs[0].Types.Scope().Names() + out = make([]types.Object, 0, len(names)) + for _, name := range names { + obj := pkgs[0].Types.Scope().Lookup(name) + out = append(out, obj) + } + return out, nil +} From 2c4ad519907c1aefa183b0ab5fb875064ff65bf6 Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Tue, 7 Apr 2026 18:36:38 -0500 Subject: [PATCH 2/2] finalize base code gen --- enumify.go | 95 +++++++++++++++++++++++++++++++++++++++++++++++------- typedef.go | 50 +++++++++++++++++++++++++--- 2 files changed, 129 insertions(+), 16 deletions(-) diff --git a/enumify.go b/enumify.go index c276dc9..5b4eb93 100644 --- a/enumify.go +++ b/enumify.go @@ -139,6 +139,7 @@ func Generate(opts Options) (err error) { // Manage import names f.ImportName("go.rtnl.ai/enumify", "enumify") f.ImportName("encoding/json", "json") + f.ImportName("database/sql/driver", "driver") var etypes EnumTypes if etypes, err = Discover(opts); err != nil { @@ -147,8 +148,12 @@ func Generate(opts Options) (err error) { // Write the code for the enum types. for _, etype := range etypes { - typeID := etype.Id() - namesVar := etype.NamesVarId() + evar := g.Id("err") // err variable name + s := etype.ReceiverId() // e if Enum or s if Status etc. + ptrS := g.Op("*").Add(etype.ReceiverId()) // *e if Enum or *s if Status etc. + typeID := etype.Id() // Enum + namesVar := etype.NamesVarId() // enumNames + namesVarSlice := etype.NamesVarSliceId() // enumNames or enumNames[0] f.Comment("//============================================================================") f.Comment("// " + etype.Name + " Enum Type: Generated Functions and Methods") @@ -168,13 +173,12 @@ func Generate(opts Options) (err error) { f.Commentf("Parse%s parses the given value into a %s.", etype.Name, etype.Name) parserFunc := f.Func().Id("Parse" + UpperFirst(etype.Name)) - parserFunc.Params(g.Id("s").Any()).Params(typeID, g.Error()).Block( - g.Return(parserVar.Clone().Call(g.Id("s"))), + parserFunc.Params(s.Clone().Any()).Params(typeID, g.Error()).Block( + g.Return(parserVar.Clone().Call(s.Clone())), ) f.Line() } - s := g.Id("s") methodSig := g.Func().Params(s.Clone().Add(typeID)) methodPtrSig := g.Func().Params(s.Clone().Op("*").Add(typeID)) @@ -184,7 +188,7 @@ func Generate(opts Options) (err error) { method := methodSig.Clone().Id("String").Call().String() method.Block( - g.If(s.Clone().Op(">=").Add(typeID).Call(g.Len(namesVar))).Block( + g.If(s.Clone().Op(">=").Add(typeID).Call(g.Len(namesVarSlice))).Block( g.Return(etype.IndexNames(etype.ZeroConstId())), ), g.Return(etype.IndexNames(s)), @@ -204,19 +208,21 @@ func Generate(opts Options) (err error) { f.Line() f.Commentf("Ensure %s implements json.Unmarshaler.", etype.Name) - method = methodPtrSig.Clone().Id("UnmarshalJSON").Call(g.Id("data").Id("[]byte")).Params(g.Id("err").Error()) + + sv := g.Id("sv") + method = methodPtrSig.Clone().Id("UnmarshalJSON").Call(g.Id("data").Id("[]byte")).Params(evar.Clone().Error()) method.Block( - g.Var().Id("v").Any(), + g.Var().Add(sv).Any(), g.If( - g.Id("err").Op("=").Qual("encoding/json", "Unmarshal").Call(g.Id("data"), g.Op("&").Id("v")).Op(";").Id("err").Op("!=").Nil()). + evar.Clone().Op("=").Qual("encoding/json", "Unmarshal").Call(g.Id("data"), g.Op("&").Add(sv)).Op(";").Id("err").Op("!=").Nil()). Block( - g.Return(g.Id("err")), + g.Return(evar), ), g.Line(), g.If( - g.Op("*").Add(s).Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(g.Id("v")).Op(";").Id("err").Op("!=").Nil(). + ptrS.Clone().Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(sv).Op(";").Id("err").Op("!=").Nil(). Block( - g.Return(g.Id("err")), + g.Return(evar), ), ), g.Return().Nil(), @@ -225,6 +231,71 @@ func Generate(opts Options) (err error) { f.Line() } + // YAML Marshal and Unmarshal code + if !opts.NoYAML { + f.Commentf("Ensure %s implements yaml.Marshaler.", etype.Name) + method := methodSig.Clone().Id("MarshalYAML").Call().Params(g.Any(), g.Error()) + method.Block( + g.Return(s.Clone().Dot("String").Call(), g.Nil()), + ) + f.Add(method) + f.Line() + + f.Commentf("Ensure %s implements yaml.Unmarshaler.", etype.Name) + + sv := g.Id("sv") + method = methodPtrSig.Clone().Id("UnmarshalYAML").Call(g.Id("unmarshal").Func().Params(g.Any()).Params(g.Error())).Params(evar.Clone().Error()) + method.Block( + g.Var().Add(sv).String(), + g.If(evar.Clone().Op("=").Id("unmarshal").Call(g.Op("&").Add(sv)).Op(";").Id("err").Op("!=").Nil()).Block( + g.Return(evar), + ), + g.Line(), + g.If(g.Add(ptrS.Clone()).Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(sv).Op(";").Id("err").Op("!=").Nil(). + Block( + g.Return(evar), + ), + ), + g.Return(g.Nil()), + ) + f.Add(method) + f.Line() + } + + // SQL Scanner and Valuer code + if !opts.NoSQL { + f.Commentf("Ensure %s implements sql.Scanner.", etype.Name) + val := g.Id("val") + method := methodPtrSig.Clone().Id("Scan").Call(g.Id("src").Any()).Params(evar.Clone().Error()) + method.Block( + g.Switch(val.Clone().Op(":=").Id("src").Assert(g.Type())).Block( + g.Case(g.Nil()).Block(g.Return(g.Nil())), + g.Case(g.String()).Block( + g.List(ptrS.Clone(), evar.Clone().Op("=").Id("Parse"+etype.Name).Call(val)), + g.Return(evar), + ), + g.Case(g.Id("[]byte")).Block( + g.List(ptrS.Clone(), evar.Clone().Op("=").Id("Parse"+etype.Name).Call(g.String().Call(val))), + g.Return(evar), + ), + g.Default().Block( + g.Return(g.Qual("fmt", "Errorf").Call(g.Lit("cannot scan %T into "+etype.Name), val)), + ), + ), + ) + + f.Add(method) + f.Line() + + f.Commentf("Ensure %s implements driver.Valuer.", etype.Name) + method = methodSig.Clone().Id("Value").Call().Params(g.Qual("database/sql/driver", "Value"), g.Error()) + method.Block( + g.Return(s.Clone().Dot("String").Call(), g.Nil()), + ) + f.Add(method) + f.Line() + } + } if err = f.Save(opts.fileName()); err != nil { diff --git a/typedef.go b/typedef.go index fac1a4e..55c23c0 100644 --- a/typedef.go +++ b/typedef.go @@ -4,6 +4,9 @@ import ( "fmt" "go/constant" "go/types" + "slices" + "unicode" + "unicode/utf8" g "github.com/dave/jennifer/jen" "golang.org/x/tools/go/packages" @@ -28,6 +31,19 @@ func (e *EnumType) Id() *g.Statement { return g.Id(e.Name) } +func (e *EnumType) ReceiverId() *g.Statement { + if e.Name == "" { + return g.Id("e") + } + + r, _ := utf8.DecodeRuneInString(e.Name) + if r == utf8.RuneError { + return g.Id("e") + } + + return g.Id(string(unicode.ToLower(r))) +} + func (e *EnumType) NamesVarId() *g.Statement { return g.Id(e.NamesVar.Name()) } @@ -43,6 +59,17 @@ func (e *EnumType) NamesVarTypeId() *g.Statement { } } +func (e *EnumType) NamesVarSliceId() *g.Statement { + switch { + case isStringTable(e.NamesVar.Type()): + return e.NamesVarId().Index(g.Lit(0)) + case isStringSlice(e.NamesVar.Type()): + return e.NamesVarId() + default: + panic(fmt.Errorf("unsupported names variable type: %T", e.NamesVar.Type())) + } +} + func (e *EnumType) IndexNames(i g.Code) *g.Statement { switch { case isStringTable(e.NamesVar.Type()): @@ -127,10 +154,8 @@ func (e *EnumType) validate() error { func (e *EnumType) zeroConst() types.Object { for _, obj := range e.Consts { - if c, ok := obj.(*types.Const); ok { - if v, ok := constant.Uint64Val(c.Val()); ok && v == 0 { - return obj - } + if v, err := constValue(obj); err == nil && v == 0 { + return obj } } return nil @@ -163,6 +188,13 @@ func (e *EnumType) discover() { continue } } + + // Sort the consts by their values. + slices.SortFunc(e.Consts, func(i, j types.Object) int { + valI, _ := constValue(i) + valJ, _ := constValue(j) + return int(valI - valJ) + }) } func (e *EnumType) setNamesVar(name string) error { @@ -236,3 +268,13 @@ func isStrings2DArray(typ types.Type) bool { } return isStringArray(outer.Elem()) } + +func constValue(obj types.Object) (uint64, error) { + if c, ok := obj.(*types.Const); ok { + if v, ok := constant.Uint64Val(c.Val()); ok { + return v, nil + } + return 0, fmt.Errorf("const %q has no uint64 value", obj.Name()) + } + return 0, fmt.Errorf("%q is not a constant", obj.Name()) +}