@@ -5,6 +5,12 @@ import (
55 "database/sql/driver"
66 "encoding/json"
77 "fmt"
8+ "go/types"
9+ "path/filepath"
10+ "strings"
11+
12+ g "github.com/dave/jennifer/jen"
13+ "golang.org/x/tools/go/packages"
814)
915
1016type Enum interface {
@@ -35,11 +41,112 @@ func GenerateComment() string {
3541 return fmt .Sprintf ("Code generated by enumify v%s. DO NOT EDIT." , Version ())
3642}
3743
38- func Generate (opts Options ) error {
39- fmt .Printf ("%+v\n " , opts )
44+ func Generate (opts Options ) (err error ) {
45+ f := g .NewFile (opts .Pkg )
46+ f .PackageComment (GenerateComment ())
47+
48+ if _ , err = opts .discover (); err != nil {
49+ return err
50+ }
51+
52+ if err = f .Save (opts .fileName ()); err != nil {
53+ return err
54+ }
4055 return nil
4156}
4257
43- func GenerateTests (opts Options ) error {
58+ func GenerateTests (opts Options ) (err error ) {
59+ f := g .NewFile (opts .Pkg + "_test" )
60+ f .PackageComment (GenerateComment ())
61+
62+ if err = f .Save (opts .testFileName ()); err != nil {
63+ return err
64+ }
4465 return nil
4566}
67+
68+ func (o Options ) fileName () string {
69+ ext := filepath .Ext (o .File )
70+ return strings .TrimSuffix (filepath .Base (o .File ), ext ) + "_gen" + ext
71+ }
72+
73+ func (o Options ) testFileName () string {
74+ ext := filepath .Ext (o .File )
75+ return strings .TrimSuffix (filepath .Base (o .File ), ext ) + "_gen_test" + ext
76+ }
77+
78+ func (o * Options ) discover () (etypes EnumTypes , err error ) {
79+ // Build tool package discovery configuration.
80+ cfg := & packages.Config {
81+ Mode : packages .NeedTypes | packages .NeedTypesInfo ,
82+ }
83+
84+ // NOTE: do not use the driver query file={os.File} here because it will load the
85+ // entire package instead of just the contents of the file. As a result, the
86+ // types discovered by packages.Load will have the package "command-line-arguments"
87+ // scope rather than the scope of the package being inspected.
88+ //
89+ // We prefer this so we can isolate the specific files that have a go generate
90+ // directive and ignore the other files including other files that may also have
91+ // go generate directives.
92+ //
93+ // TODO: what if multiple enums are defined in the same file?
94+ var pkgs []* packages.Package
95+ if pkgs , err = packages .Load (cfg , o .File ); err != nil {
96+ return nil , fmt .Errorf ("failed to load package %q for inspection: %w" , o .File , err )
97+ }
98+
99+ if len (pkgs ) == 0 {
100+ return nil , fmt .Errorf ("no packages found for inspection" )
101+ }
102+
103+ if len (pkgs ) > 1 {
104+ return nil , fmt .Errorf ("multiple packages found for inspection: %v" , pkgs )
105+ }
106+
107+ gopkg := pkgs [0 ]
108+ if len (gopkg .Errors ) > 0 {
109+ return nil , fmt .Errorf ("package errors: %v" , pkgs [0 ].Errors )
110+ }
111+
112+ // Get the predeclared uint8 type for comparison
113+ uint8Type := types .Typ [types .Uint8 ]
114+
115+ // First pass: discover all the enum types in the package.
116+ // An enum type is a type whose underlying type is uint8.
117+ etypes = make (EnumTypes , 0 , 1 )
118+ scope := gopkg .Types .Scope ()
119+ for _ , name := range scope .Names () {
120+ obj := scope .Lookup (name )
121+ if typ , ok := obj .(* types.TypeName ); ok {
122+ if types .Identical (typ .Type ().Underlying (), uint8Type ) {
123+ etypes = append (etypes , & EnumType {
124+ Name : name ,
125+ Type : obj .Type (),
126+ gopkg : gopkg ,
127+ scope : scope ,
128+ })
129+ }
130+ }
131+ }
132+
133+ // Second pass: populate the enum types with consts and the name variable.
134+ for _ , etype := range etypes {
135+ // Discover the consts and names variable for the enum type.
136+ etype .discover ()
137+
138+ // If the names variable is not set, but one was passed in from the command
139+ // line, then attempt to set it on the enum type.
140+ if etype .NamesVar == nil && o .NameVar != "" {
141+ if err = etype .setNamesVar (o .NameVar ); err != nil {
142+ return nil , fmt .Errorf ("failed to set names variable for enum type %q: %w" , etype .Name , err )
143+ }
144+ }
145+
146+ // The enum type must be valid before we can generate code for it.
147+ if err = etype .validate (); err != nil {
148+ return nil , err
149+ }
150+ }
151+ return etypes , nil
152+ }
0 commit comments