diff --git a/.gitignore b/.gitignore index ae4211a..3bff00b 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ go.work.sum # Do not track generated example code in order to ensure tests pass. example/*_gen.go +example/*_gen_test.go diff --git a/enumify.go b/enumify.go index f5791f8..3a7256e 100644 --- a/enumify.go +++ b/enumify.go @@ -5,6 +5,12 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "go/types" + "path/filepath" + "strings" + + g "github.com/dave/jennifer/jen" + "golang.org/x/tools/go/packages" ) type Enum interface { @@ -35,11 +41,112 @@ func GenerateComment() string { return fmt.Sprintf("Code generated by enumify v%s. DO NOT EDIT.", Version()) } -func Generate(opts Options) error { - fmt.Printf("%+v\n", opts) +func Generate(opts Options) (err error) { + f := g.NewFile(opts.Pkg) + f.PackageComment(GenerateComment()) + + if _, err = opts.discover(); err != nil { + return err + } + + if err = f.Save(opts.fileName()); err != nil { + return err + } return nil } -func GenerateTests(opts Options) error { +func GenerateTests(opts Options) (err error) { + f := g.NewFile(opts.Pkg + "_test") + f.PackageComment(GenerateComment()) + + if err = f.Save(opts.testFileName()); err != nil { + return err + } return nil } + +func (o Options) fileName() string { + ext := filepath.Ext(o.File) + return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen" + ext +} + +func (o Options) testFileName() string { + ext := filepath.Ext(o.File) + return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen_test" + ext +} + +func (o *Options) discover() (etypes EnumTypes, err error) { + // Build tool package discovery configuration. + cfg := &packages.Config{ + Mode: packages.NeedTypes | packages.NeedTypesInfo, + } + + // NOTE: do not use the driver query file={os.File} here because it will load the + // entire package instead of just the contents of the file. As a result, the + // types discovered by packages.Load will have the package "command-line-arguments" + // scope rather than the scope of the package being inspected. + // + // We prefer this so we can isolate the specific files that have a go generate + // directive and ignore the other files including other files that may also have + // go generate directives. + // + // TODO: what if multiple enums are defined in the same file? + var pkgs []*packages.Package + if pkgs, err = packages.Load(cfg, o.File); err != nil { + return nil, fmt.Errorf("failed to load package %q for inspection: %w", o.File, err) + } + + if len(pkgs) == 0 { + return nil, fmt.Errorf("no packages found for inspection") + } + + if len(pkgs) > 1 { + return nil, fmt.Errorf("multiple packages found for inspection: %v", pkgs) + } + + gopkg := pkgs[0] + if len(gopkg.Errors) > 0 { + return nil, fmt.Errorf("package errors: %v", pkgs[0].Errors) + } + + // Get the predeclared uint8 type for comparison + uint8Type := types.Typ[types.Uint8] + + // First pass: discover all the enum types in the package. + // An enum type is a type whose underlying type is uint8. + etypes = make(EnumTypes, 0, 1) + scope := gopkg.Types.Scope() + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if typ, ok := obj.(*types.TypeName); ok { + if types.Identical(typ.Type().Underlying(), uint8Type) { + etypes = append(etypes, &EnumType{ + Name: name, + Type: obj.Type(), + gopkg: gopkg, + scope: scope, + }) + } + } + } + + // Second pass: populate the enum types with consts and the name variable. + for _, etype := range etypes { + // Discover the consts and names variable for the enum type. + etype.discover() + + // If the names variable is not set, but one was passed in from the command + // line, then attempt to set it on the enum type. + if etype.NamesVar == nil && o.NameVar != "" { + if err = etype.setNamesVar(o.NameVar); err != nil { + return nil, fmt.Errorf("failed to set names variable for enum type %q: %w", etype.Name, err) + } + } + + // The enum type must be valid before we can generate code for it. + if err = etype.validate(); err != nil { + return nil, err + } + } + return etypes, nil +} diff --git a/example/calendar.go b/example/calendar.go new file mode 100644 index 0000000..1772629 --- /dev/null +++ b/example/calendar.go @@ -0,0 +1,75 @@ +package example + +// Day is an enum type that should be implemented by the enumify generator. +// It uses a 1D array of strings for the names, which should be discovered by the +// enumify generator due to the go:generate directive above the Day type declaration. +// +//go:generate go run ../cmd/enumify +type Day uint8 + +// Constants for the Day enum values. +// These values should be discovered by the enumify generator since they use the same +// type as the Day enum, which is the type being generated. +const ( + Unknown Day = iota + Monday + Tuesday + Wednesday + Thursday + Friday + Saturday + Sunday +) + +// 1D array of strings for the names of the Day enum values. +// This should be discovered by the enumify generator due to the go:generate directive +// and because it matches the dayNames pattern to connect it with the Day enum. +// +//lint:ignore U1000 this is used by the enumify generator +var dayNames = [...]string{ + "unknown", + "Monday", + "Tuesday", + "Wednesday", + "Thursday", + "Friday", + "Saturday", + "Sunday", +} + +// This enum should be discovered by the enumify generator due to the go:generate +// directive on line 7 of this file and because it matches the Enum spec pattern. +type Month uint8 + +const ( + Monthless Month = iota + January + February + March + April + May + June + July + August + September + October + November + December +) + +// This 2D array of strings should match the names pattern for the color enum without +// having to specify it using the -names flag. +// +//lint:ignore U1000 this is used by the enumify generator +var monthNames = [2][13]string{ + {"", "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"}, + {"", "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}, +} + +// This additional enum method should be ignored by the enumify generator. +func (m Month) Abbreviation() string { + if m >= Month(len(monthNames[1])) { + return monthNames[1][Monthless] + } + return monthNames[1][m] +} diff --git a/example/days.go b/example/days.go deleted file mode 100644 index 74c0464..0000000 --- a/example/days.go +++ /dev/null @@ -1,38 +0,0 @@ -package example - -// Day is an enum type that should be implemented by the enumify generator. -// It uses a 1D array of strings for the names, which should be discovered by the -// enumify generator due to the go:generate directive above the Day type declaration. -// -//go:generate go run ../cmd/enumify -type Day uint8 - -// Constants for the Day enum values. -// These values should be discovered by the enumify generator since they use the same -// type as the Day enum, which is the type being generated. -const ( - Unknown Day = iota - Monday - Tuesday - Wednesday - Thursday - Friday - Saturday - Sunday -) - -// 1D array of strings for the names of the Day enum values. -// This should be discovered by the enumify generator due to the go:generate directive -// and because it matches the dayNames pattern to connect it with the Day enum. -// -//lint:ignore U1000 this is used by the enumify generator -var dayNames = []string{ - "unknown", - "Monday", - "Tuesday", - "Wednesday", - "Thursday", - "Friday", - "Saturday", - "Sunday", -} diff --git a/example/status.go b/example/status.go index a994df3..9f4c749 100644 --- a/example/status.go +++ b/example/status.go @@ -28,3 +28,6 @@ var statusTable = [][]string{ {"unknown", "pending", "running", "failed", "success", "cancelled"}, {"text-secondary", "text-info", "text-primary", "text-danger", "text-success", "text-warning"}, } + +// This is an unrelated type that should be ignored by the enumify generator. +type Foo struct{} diff --git a/go.mod b/go.mod index c9c2501..2f69b70 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,16 @@ module go.rtnl.ai/enumify go 1.26.1 require ( + github.com/dave/jennifer v1.7.1 github.com/stretchr/testify v1.11.1 go.rtnl.ai/x v1.15.0 + golang.org/x/tools v0.43.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/mod v0.34.0 // indirect + golang.org/x/sync v0.20.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ec9d4c3..4d63bab 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,21 @@ +github.com/dave/jennifer v1.7.1 h1:B4jJJDHelWcDhlRQxWeo0Npa/pYKBLrirAQoTN45txo= +github.com/dave/jennifer v1.7.1/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.rtnl.ai/x v1.15.0 h1:tzMqlAXrwZ4CHNscAawlBbMjDvEwZxSu9AMxJB4CPOs= go.rtnl.ai/x v1.15.0/go.mod h1:ciQ9PaXDtZDznzBrGDBV2yTElKX3aJgtQfi6V8613bo= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/typedef.go b/typedef.go new file mode 100644 index 0000000..29c1812 --- /dev/null +++ b/typedef.go @@ -0,0 +1,179 @@ +package enumify + +import ( + "fmt" + "go/constant" + "go/types" + "unicode" + + "golang.org/x/tools/go/packages" +) + +type EnumTypes []*EnumType + +type EnumType struct { + Name string + Type types.Type + Consts []types.Object + NamesVar types.Object + gopkg *packages.Package + scope *types.Scope +} + +func (e *EnumType) validate() error { + // Not sure how this would happen ... but we're going to LBYL becuase types are hard. + if e.Name == "" || e.Type == nil { + return fmt.Errorf("enum type %q has no name or type", e.Name) + } + + // Need to have at least one constant declared for the enum type. + nConsts := len(e.Consts) + if nConsts == 0 { + return fmt.Errorf("enum type %q has no constants", e.Name) + } + + // Ensure all discovered constants are actually constants. + for _, obj := range e.Consts { + if _, ok := obj.(*types.Const); !ok { + return fmt.Errorf("enum type %q has a non-constant value: %T", e.Name, obj) + } + } + + // Need to have a names variable declared for the enum type. + if e.NamesVar == nil { + return fmt.Errorf("enum type %q has no names variable", e.Name) + } + + // There must be a zero-valued constant declared for the enum type. + if e.zeroConst() == nil { + return fmt.Errorf("enum type %q has no zero-valued constant", e.Name) + } + + // The number of constants must match the number of names in the names variable. + if namesVar, ok := e.NamesVar.(*types.Var); ok { + if namesVar.Kind() != types.PackageVar { + return fmt.Errorf("enum type %q has a non-package variable names: %T", e.Name, e.NamesVar) + } + + if !isNamesType(namesVar.Type()) { + return fmt.Errorf("enum type %q has a names variable that is not a string slice or string table: %T", e.Name, e.NamesVar.Type().Underlying()) + } + + // TODO: use the packages.NeedsSyntax mode to get the ast and find the + // CompositeLit that represents the names variable. You can then use its Elts + // property to get the number of elements and compare that to the number of + // constants. Unfortunately, there is no mapping from a types.Object to an + // ast.Node to do checking if it is an ast.CompositeLit so ast traversal is + // required, which is a bit too far for me to implement right now. + } + + return nil +} + +func (e *EnumType) zeroConst() types.Object { + for _, obj := range e.Consts { + if c, ok := obj.(*types.Const); ok { + if v, ok := constant.Uint64Val(c.Val()); ok && v == 0 { + return obj + } + } + } + return nil +} + +// Attempts automatic discovery of the enum constants and 1D or 2D string slice variable +// that contains the string representations of the enum values. +func (e *EnumType) discover() { + e.Consts = make([]types.Object, 0, 1) + e.NamesVar = nil + + for _, name := range e.scope.Names() { + // Constants are the same type as the enum type. + obj := e.scope.Lookup(name) + if types.Identical(obj.Type(), e.Type) { + // Only add constants (skipping at the very least, the declared type) + if _, ok := obj.(*types.Const); !ok { + continue + } + + e.Consts = append(e.Consts, obj) + continue + } + + // Name match means that the variable name is [name]Names. + // If the name matches this pattern and it is a string slice or string table, + // then it is set as the names variable. + if e.namesMatch(name) && isNamesType(obj.Type()) { + e.NamesVar = obj + continue + } + } +} + +func (e *EnumType) setNamesVar(name string) error { + if e.NamesVar != nil { + if e.NamesVar.Name() != name { + return fmt.Errorf("already set or discovered names variable %q, cannot set to %q", e.NamesVar.Name(), name) + } + return nil + } + + var obj types.Object + if obj = e.scope.Lookup(name); obj == nil { + return fmt.Errorf("names variable %q not found", name) + } + + if !isNamesType(obj.Type()) { + return fmt.Errorf("names variable %q is not a string slice, string array, or string table", name) + } + + // Success! Set the names variable and return. + e.NamesVar = obj + return nil +} + +func (e *EnumType) namesMatch(name string) bool { + chr := []rune(e.Name)[0] + ucc := string(unicode.ToUpper(chr)) + e.Name[1:] + "Names" + lcc := string(unicode.ToLower(chr)) + e.Name[1:] + "Names" + return name == ucc || name == lcc +} + +func isNamesType(typ types.Type) bool { + return isStringArray(typ) || isStringSlice(typ) || isStringTable(typ) || isStrings2DArray(typ) +} + +func isStringArray(typ types.Type) bool { + array, ok := typ.Underlying().(*types.Array) + if !ok { + return false + } + basic, ok := array.Elem().Underlying().(*types.Basic) + return ok && basic.Kind() == types.String +} + +func isStringSlice(typ types.Type) bool { + slice, ok := typ.Underlying().(*types.Slice) + if !ok { + return false + } + + basic, ok := slice.Elem().Underlying().(*types.Basic) + return ok && basic.Kind() == types.String +} + +func isStringTable(typ types.Type) bool { + outer, ok := typ.Underlying().(*types.Slice) + if !ok { + return false + } + return isStringSlice(outer.Elem()) +} + +func isStrings2DArray(typ types.Type) bool { + outer, ok := typ.Underlying().(*types.Array) + if !ok { + return false + } + return isStringArray(outer.Elem()) +}