Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/enumify/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ func main() {

// Bind flag variables directly to the opts struct
flag.StringVar(&opts.NameVar, "names", "", "variable name that contains the string reprs of the enum values")
flag.BoolVar(&opts.CaseSensitive, "case-sensitive", false, "make the enum case sensitive")
flag.BoolVar(&opts.SpaceSensitive, "space-sensitive", false, "make the enum space sensitive")
flag.BoolVar(&opts.NoTests, "no-tests", false, "skip testing code generation")
flag.BoolVar(&opts.NoParser, "no-parser", false, "skip parser code generation")
flag.BoolVar(&opts.NoStringer, "no-stringer", false, "skip Stringer interface code generation")
Expand Down
303 changes: 253 additions & 50 deletions enumify.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,58 +24,27 @@ type Enum interface {
// Options defines how the code generation should be performed. These values are set by
// the CLI flags passed via the go generate directive.
type Options struct {
File string // The file that kicked off the generation (from $GOFILE)
Pkg string // The package that the enum is in (from $GOPACKAGE)
NameVar string // The variable name that contains the names of the enum values
NoTests bool // Whether to skip the testing code generation (defaults to false)
NoParser bool // Whether to skip the parser code generation (defaults to false)
NoStringer bool // Whether to skip the Stringer interface code generation (defaults to false)
NoText bool // Whether to skip the text interfaces code generation (defaults to false)
NoBinary bool // Whether to skip the binary interfaces code generation (defaults to false)
NoJSON bool // Whether to skip the JSON interfaces code generation (defaults to false)
NoYAML bool // Whether to skip the YAML interfaces code generation (defaults to false)
NoSQL bool // Whether to skip the SQL interfaces code generation (defaults to false)
File string // The file that kicked off the generation (from $GOFILE)
Pkg string // The package that the enum is in (from $GOPACKAGE)
NameVar string // The variable name that contains the names of the enum values
CaseSensitive bool // Whether to make the enum case sensitive (defaults to false)
SpaceSensitive bool // Whether to make the enum space sensitive (defaults to false)
NoTests bool // Whether to skip the testing code generation (defaults to false)
NoParser bool // Whether to skip the parser code generation (defaults to false)
NoStringer bool // Whether to skip the Stringer interface code generation (defaults to false)
NoText bool // Whether to skip the text interfaces code generation (defaults to false)
NoBinary bool // Whether to skip the binary interfaces code generation (defaults to false)
NoJSON bool // Whether to skip the JSON interfaces code generation (defaults to false)
NoYAML bool // Whether to skip the YAML interfaces code generation (defaults to false)
NoSQL bool // Whether to skip the SQL interfaces code generation (defaults to false)
}

func GenerateComment() string {
return fmt.Sprintf("Code generated by enumify v%s. DO NOT EDIT.", Version())
}

func Generate(opts Options) (err error) {
f := g.NewFile(opts.Pkg)
f.PackageComment(GenerateComment())

if _, err = opts.discover(); err != nil {
return err
}

if err = f.Save(opts.fileName()); err != nil {
return err
}
return nil
}

func GenerateTests(opts Options) (err error) {
f := g.NewFile(opts.Pkg + "_test")
f.PackageComment(GenerateComment())

if err = f.Save(opts.testFileName()); err != nil {
return err
}
return nil
}

func (o Options) fileName() string {
ext := filepath.Ext(o.File)
return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen" + ext
}

func (o Options) testFileName() string {
ext := filepath.Ext(o.File)
return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen_test" + ext
}

func (o *Options) discover() (etypes EnumTypes, err error) {
// Discover is the entry point for all code generation. It will parse the file specified
// by the go generate directive into an AST, then loop through all of the data in the
// file to discover the Enum types (anything that extends uint8), the constant values
// for the enum types, and the names variable based on the name pattern, the passed in
// name, and the allowed name var types.
func Discover(o Options) (etypes EnumTypes, err error) {
// Build tool package discovery configuration.
cfg := &packages.Config{
Mode: packages.NeedTypes | packages.NeedTypesInfo,
Expand Down Expand Up @@ -150,3 +119,237 @@ func (o *Options) discover() (etypes EnumTypes, err error) {
}
return etypes, nil
}

//============================================================================
// Code Generation Entry Points
//============================================================================

// The generation comment is considered best practice when writing golang code generators.
// In addition to the standard warning, it also includes the version of the generator that
// was used to generate the code for debugging purposes.
func GenerateComment() string {
return fmt.Sprintf("Code generated by enumify v%s. DO NOT EDIT.", Version())
}

// Generate the code file for the enum types.
func Generate(opts Options) (err error) {
f := g.NewFile(opts.Pkg)
f.PackageComment(GenerateComment())

// Manage import names
f.ImportName("go.rtnl.ai/enumify", "enumify")
f.ImportName("encoding/json", "json")
f.ImportName("database/sql/driver", "driver")

var etypes EnumTypes
if etypes, err = Discover(opts); err != nil {
return err
}

// Write the code for the enum types.
for _, etype := range etypes {
evar := g.Id("err") // err variable name
s := etype.ReceiverId() // e if Enum or s if Status etc.
ptrS := g.Op("*").Add(etype.ReceiverId()) // *e if Enum or *s if Status etc.
typeID := etype.Id() // Enum
namesVar := etype.NamesVarId() // enumNames
namesVarSlice := etype.NamesVarSliceId() // enumNames or enumNames[0]

f.Comment("//============================================================================")
f.Comment("// " + etype.Name + " Enum Type: Generated Functions and Methods")
f.Comment("//============================================================================")
f.Line()

// Write the parser code for the enum type
if !opts.NoParser {
// TODO: this uses the factory function to create the parser function. This
// is not ideal because it creates a dependency on the go.rtnl.ai/enumify
// package. See https://github.com/rotationalio/enumify/issues/7 for
// ideas on how to improve this.
parserVar := g.Id(LowerFirst(etype.Name) + "Parser")
parserType := g.Func().Params(g.Any()).Params(typeID, g.Error())
f.Var().Add(parserVar, parserType).Op("=").Qual("go.rtnl.ai/enumify", "ParseFactory").Types(typeID).Call(namesVar)
f.Line()

f.Commentf("Parse%s parses the given value into a %s.", etype.Name, etype.Name)
parserFunc := f.Func().Id("Parse" + UpperFirst(etype.Name))
parserFunc.Params(s.Clone().Any()).Params(typeID, g.Error()).Block(
g.Return(parserVar.Clone().Call(s.Clone())),
)
f.Line()
}

methodSig := g.Func().Params(s.Clone().Add(typeID))
methodPtrSig := g.Func().Params(s.Clone().Op("*").Add(typeID))

// Write the stringer code for the enum type
if !opts.NoStringer {
f.Commentf("Ensure %s implements fmt.Stringer.", etype.Name)

method := methodSig.Clone().Id("String").Call().String()
method.Block(
g.If(s.Clone().Op(">=").Add(typeID).Call(g.Len(namesVarSlice))).Block(
g.Return(etype.IndexNames(etype.ZeroConstId())),
),
g.Return(etype.IndexNames(s)),
)
f.Add(method)
f.Line()
}

// JSON Marshal and Unmarshal code
if !opts.NoJSON {
f.Commentf("Ensure %s implements json.Marshaler.", etype.Name)
method := methodSig.Clone().Id("MarshalJSON").Call().Params(g.Id("[]byte"), g.Error())
method.Block(
g.Return(g.Qual("encoding/json", "Marshal").Call(s.Clone().Dot("String").Call())),
)
f.Add(method)
f.Line()

f.Commentf("Ensure %s implements json.Unmarshaler.", etype.Name)

sv := g.Id("sv")
method = methodPtrSig.Clone().Id("UnmarshalJSON").Call(g.Id("data").Id("[]byte")).Params(evar.Clone().Error())
method.Block(
g.Var().Add(sv).Any(),
g.If(
evar.Clone().Op("=").Qual("encoding/json", "Unmarshal").Call(g.Id("data"), g.Op("&").Add(sv)).Op(";").Id("err").Op("!=").Nil()).
Block(
g.Return(evar),
),
g.Line(),
g.If(
ptrS.Clone().Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(sv).Op(";").Id("err").Op("!=").Nil().
Block(
g.Return(evar),
),
),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnmarshalJSON references parser not generated when NoParser set

Medium Severity

The generated UnmarshalJSON and MarshalJSON methods implicitly depend on other generated code, specifically Parse<TypeName> and String() respectively. If NoJSON is false but NoParser or NoStringer are true, the generated code will reference undefined functions, causing compilation failures.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 862b2c1. Configure here.

g.Return().Nil(),
)
f.Add(method)
f.Line()
}

// YAML Marshal and Unmarshal code
if !opts.NoYAML {
f.Commentf("Ensure %s implements yaml.Marshaler.", etype.Name)
method := methodSig.Clone().Id("MarshalYAML").Call().Params(g.Any(), g.Error())
method.Block(
g.Return(s.Clone().Dot("String").Call(), g.Nil()),
)
f.Add(method)
f.Line()

f.Commentf("Ensure %s implements yaml.Unmarshaler.", etype.Name)

sv := g.Id("sv")
method = methodPtrSig.Clone().Id("UnmarshalYAML").Call(g.Id("unmarshal").Func().Params(g.Any()).Params(g.Error())).Params(evar.Clone().Error())
method.Block(
g.Var().Add(sv).String(),
g.If(evar.Clone().Op("=").Id("unmarshal").Call(g.Op("&").Add(sv)).Op(";").Id("err").Op("!=").Nil()).Block(
g.Return(evar),
),
g.Line(),
g.If(g.Add(ptrS.Clone()).Op(",").Id("err").Op("=").Id("Parse"+etype.Name).Call(sv).Op(";").Id("err").Op("!=").Nil().
Block(
g.Return(evar),
),
),
g.Return(g.Nil()),
)
f.Add(method)
f.Line()
}

// SQL Scanner and Valuer code
if !opts.NoSQL {
f.Commentf("Ensure %s implements sql.Scanner.", etype.Name)
val := g.Id("val")
method := methodPtrSig.Clone().Id("Scan").Call(g.Id("src").Any()).Params(evar.Clone().Error())
method.Block(
g.Switch(val.Clone().Op(":=").Id("src").Assert(g.Type())).Block(
g.Case(g.Nil()).Block(g.Return(g.Nil())),
g.Case(g.String()).Block(
g.List(ptrS.Clone(), evar.Clone().Op("=").Id("Parse"+etype.Name).Call(val)),
g.Return(evar),
),
g.Case(g.Id("[]byte")).Block(
g.List(ptrS.Clone(), evar.Clone().Op("=").Id("Parse"+etype.Name).Call(g.String().Call(val))),
g.Return(evar),
),
g.Default().Block(
g.Return(g.Qual("fmt", "Errorf").Call(g.Lit("cannot scan %T into "+etype.Name), val)),
),
),
)

f.Add(method)
f.Line()

f.Commentf("Ensure %s implements driver.Valuer.", etype.Name)
method = methodSig.Clone().Id("Value").Call().Params(g.Qual("database/sql/driver", "Value"), g.Error())
method.Block(
g.Return(s.Clone().Dot("String").Call(), g.Nil()),
)
f.Add(method)
f.Line()
}

}

if err = f.Save(opts.fileName()); err != nil {
return err
}
return nil
}

// Generate the test file for the enum types.
func GenerateTests(opts Options) (err error) {
f := g.NewFile(opts.Pkg)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test file generated with wrong package name suffix

Medium Severity

GenerateTests uses g.NewFile(opts.Pkg) instead of g.NewFile(opts.Pkg + "_test"), generating white-box tests in the same package. While this works for accessing unexported names vars, it means the generated test file (saved as *_gen_test.go) shares the package namespace with the source, which can cause symbol collisions with other test files in the same package that use package <name>_test (black-box testing). The old code used opts.Pkg + "_test" — if the intent was to switch to white-box tests, the generated test code also needs to avoid referencing enumify.TestSuite with a qualified import since it's now in the same compilation context as the source.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 862b2c1. Configure here.

f.PackageComment(GenerateComment())

// Manage import names
f.ImportName("testing", "testing")
f.ImportName("go.rtnl.ai/enumify", "enumify")

// Discover the enum types in the package.
var etypes EnumTypes
if etypes, err = Discover(opts); err != nil {
return err
}

// For each enum type, write a test function that creates and executes an enumify
// test suite for the enum type.
for _, etype := range etypes {
f.Func().Id("Test"+UpperFirst(etype.Name)).Add(TestingT).Block(
g.Id("suite").Op(":=").Qual("go.rtnl.ai/enumify", "TestSuite").Types(etype.Id(), etype.NamesVarTypeId()).Block(
g.Id("Values").Op(":").Add(etype.ConstLiteral()).Op(","),
g.Id("Names").Op(":").Add(etype.NamesVarId()).Op(","),
g.Id("ICase").Op(":").Lit(!opts.CaseSensitive).Op(","),
g.Id("ISpace").Op(":").Lit(!opts.SpaceSensitive).Op(","),
),
g.Line(),
g.Id("suite").Dot("Run").Call(g.Id("t")),
).Line()
}

if err = f.Save(opts.testFileName()); err != nil {
return err
}
return nil
}

//============================================================================
// Options Methods
//============================================================================

func (o Options) fileName() string {
ext := filepath.Ext(o.File)
return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen" + ext
}

func (o Options) testFileName() string {
ext := filepath.Ext(o.File)
return strings.TrimSuffix(filepath.Base(o.File), ext) + "_gen_test" + ext
}
4 changes: 2 additions & 2 deletions example/calendar.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const (
// and because it matches the dayNames pattern to connect it with the Day enum.
//
//lint:ignore U1000 this is used by the enumify generator
var dayNames = [...]string{
var dayNames = []string{
"unknown",
"Monday",
"Tuesday",
Expand Down Expand Up @@ -61,7 +61,7 @@ const (
// having to specify it using the -names flag.
//
//lint:ignore U1000 this is used by the enumify generator
var monthNames = [2][13]string{
var monthNames = [][]string{
{"", "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"},
{"", "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"},
}
Expand Down
9 changes: 9 additions & 0 deletions export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package enumify

var (
IsStringArray = isStringArray
IsStringSlice = isStringSlice
IsStringTable = isStringTable
IsStrings2DArray = isStrings2DArray
IsNamesType = isNamesType
)
31 changes: 31 additions & 0 deletions strings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package enumify

import (
"unicode"
"unicode/utf8"
)

func LowerFirst(s string) string {
if s == "" {
return ""
}

r, size := utf8.DecodeRuneInString(s)
if r == utf8.RuneError {
return s
}

return string(unicode.ToLower(r)) + s[size:]
}

func UpperFirst(s string) string {
if s == "" {
return ""
}

r, size := utf8.DecodeRuneInString(s)
if r == utf8.RuneError {
return s
}
return string(unicode.ToUpper(r)) + s[size:]
}
Loading
Loading