From 1156df24cb3519cc6c50ebdcc10566b3ccd35fea Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Fri, 3 Apr 2026 14:02:20 -0500 Subject: [PATCH] [FEAT] Parse Generic --- README.md | 27 +++++++ enumify.go | 15 ++++ enumify_test.go | 29 ++++++++ go.mod | 8 ++ go.sum | 10 +++ parse.go | 141 +++++++++++++++++++++++++++++++++++ parse_test.go | 191 ++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 421 insertions(+) create mode 100644 enumify_test.go create mode 100644 go.sum create mode 100644 parse.go create mode 100644 parse_test.go diff --git a/README.md b/README.md index ab8a37b..8d68f2c 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,33 @@ We found ourselves generating a lot of boilerplate code for our Enums: particularly for parsing, serialization/deserialization, database storage, and tests. Really we just want to be able to describe an Enum and get all of this code for free! Enumify is the combination of a code generator for the boilerplate code as well as a package for reducing the boilerplate and ensuring that everything works as expected. +## Getting Started + +Install enumify: + +```sh +go install go.rtnl.ai/enumify/cmd/enumify@latest +``` + +This will add the `enumify` command to your `$PATH`. + +Next steps: + +1. Define enum schema file +2. Create generate command +3. Run go generate ./... +4. Run go mod tidy + +Boom - your package is ready to go with enums! + +## Theory + +For us, an `enum` is a set of specific values that have string representations. In code we want our `enum` values to be `uint8` (if you have more than 255 values you probably want a database table instead of an `enum`) for compact memory usage and ease of comparisons. However for marshaling/unmarshaling the value to json/yaml or saving the value in a database, we'd like to use the string representation instead. + +However, the boilerplate code for implementing `fmt.Stringer`, `json.Marshaler`, `json.Unmarshaler`, etc. is pretty verbose -- not to mention the testing. This library unifies all of that functionality into generic functions and also provides code generation to make it simple to define enumerations. Modifications require a change to the enum schema file and running `go generate ./...` and presto chango, a minimal set of code is added to your project, but code that is fully tested and organized. + +Key theory point: the value behind the enum shouldn't matter, just its ordering. The first value (0) should always be the "unknown" or "default" value. You can compare enums but you cannot have them be specific runes or other values. + ## License This project is licensed under the BSD 3-Clause License. See [`LICENSE`](./LICENSE) for details. Please feel free to use Enumify in your own projects and applications. diff --git a/enumify.go b/enumify.go index 9ea6f0c..bf0fedc 100644 --- a/enumify.go +++ b/enumify.go @@ -1 +1,16 @@ package enumify + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" +) + +type Enum interface { + fmt.Stringer + json.Marshaler + json.Unmarshaler + sql.Scanner + driver.Valuer +} diff --git a/enumify_test.go b/enumify_test.go new file mode 100644 index 0000000..ca56fc4 --- /dev/null +++ b/enumify_test.go @@ -0,0 +1,29 @@ +package enumify_test + +//============================================================================ +// Test Enum Type +//============================================================================ + +type Status uint8 + +const ( + StatusUnknown Status = iota + StatusDraft + StatusReview + StatusPublished + StatusArchived +) + +var StatusNames = []string{ + "unknown", + "draft", + "review", + "published", + "archived", +} + +var StatusNames2D = [][]string{ + {"unknown", "draft", "review", "published", "archived"}, + {"Unknown", "Draft", "Needs Review", "Published", "Archived"}, + {"Unbekannt", "Entwurf", "Überprüfung", "Veröffentlicht", "Archiviert"}, +} diff --git a/go.mod b/go.mod index fe122b4..efddd5c 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module go.rtnl.ai/enumify go 1.26.1 + +require github.com/stretchr/testify v1.11.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c4c1710 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/parse.go b/parse.go new file mode 100644 index 0000000..791b38b --- /dev/null +++ b/parse.go @@ -0,0 +1,141 @@ +package enumify + +import ( + "fmt" + "strings" +) + +// Generic parser function that can be used to parse a specific enum type. +type Parser[T ~uint8] func(any) (T, error) + +// ParseFactory creates a parser function for a specific enum type. The parser can parse +// the enum from a string or a numeric type that can be converted into a uint8. In order +// to support string parsing, a names array is required. It can either be a single +// array where the index of the name is the value of the enum, or a 2D array where the +// first column contains the names indexed according to the enum value. +// +// NOTE: string parsing is case-insensitive and leading and trailing whitespace is +// ignored. Names should be slug values to ensure consistent parsing. +func ParseFactory[T ~uint8, Names []string | [][]string](names Names) Parser[T] { + var normalizedNames []string + switch col := any(names).(type) { + case []string: + if len(col) < 1 { + panic(fmt.Errorf("names array must contain at least one name")) + } + + normalizedNames = make([]string, len(col)) + for i, name := range col { + normalizedNames[i] = normalize(name) + } + case [][]string: + if len(col) < 1 { + panic(fmt.Errorf("names array must contain at least one column")) + } + + if len(col[0]) < 1 { + panic(fmt.Errorf("names array must contain at least one name")) + } + + normalizedNames = make([]string, len(col[0])) + for i, name := range col[0] { + normalizedNames[i] = normalize(name) + } + } + + // The "unknown" value is the zero-valued T. + unknown := T(0) + + return func(val any) (T, error) { + switch v := val.(type) { + case string: + v = normalize(v) + + // For an empty string, return the "unknown" or zero-valued T. + if v == "" { + return unknown, nil + } + + // Iterate over the normalized names and return the value of the first match. + for i, name := range normalizedNames { + if name == v { + return T(i), nil + } + } + + // If no match is found, return an error. + return unknown, fmt.Errorf("invalid %T value: %q", unknown, v) + case T: + if v >= T(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return v, nil + case uint: + if v >= uint(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case uint8: + if v >= uint8(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case uint16: + if v >= uint16(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case uint32: + if v >= uint32(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case uint64: + if v >= uint64(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case int: + if v < 0 || v >= len(normalizedNames) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case int8: + if v < 0 || v >= int8(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case int16: + if v < 0 || v >= int16(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case int32: + if v < 0 || v >= int32(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case int64: + if v < 0 || v >= int64(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %d", unknown, v) + } + return T(v), nil + case float32: + if v < 0 || v >= float32(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %f", unknown, v) + } + return T(v), nil + case float64: + if v < 0 || v >= float64(len(normalizedNames)) { + return unknown, fmt.Errorf("invalid %T value: %f", unknown, v) + } + return T(v), nil + default: + return unknown, fmt.Errorf("cannot parse %T into %T", v, unknown) + } + } +} + +func normalize(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..e5b7d3f --- /dev/null +++ b/parse_test.go @@ -0,0 +1,191 @@ +package enumify_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "go.rtnl.ai/enumify" +) + +func TestParseFactory(t *testing.T) { + validTestCases := []struct { + value any + expected Status + }{ + {value: "", expected: StatusUnknown}, + {value: "unknown", expected: StatusUnknown}, + {value: "draft", expected: StatusDraft}, + {value: "review", expected: StatusReview}, + {value: "published", expected: StatusPublished}, + {value: "archived", expected: StatusArchived}, + {value: " ", expected: StatusUnknown}, + {value: "\n", expected: StatusUnknown}, + {value: "\t", expected: StatusUnknown}, + {value: " unknown ", expected: StatusUnknown}, + {value: "draft ", expected: StatusDraft}, + {value: " review", expected: StatusReview}, + {value: "\tpublished\t", expected: StatusPublished}, + {value: "\narchived\t", expected: StatusArchived}, + {value: "UNKNOWN", expected: StatusUnknown}, + {value: "DRAFT", expected: StatusDraft}, + {value: "REVIEW", expected: StatusReview}, + {value: "PUBLISHED", expected: StatusPublished}, + {value: "ARCHIVED", expected: StatusArchived}, + {value: uint(0), expected: StatusUnknown}, + {value: uint(1), expected: StatusDraft}, + {value: uint(2), expected: StatusReview}, + {value: uint(3), expected: StatusPublished}, + {value: uint(4), expected: StatusArchived}, + {value: uint8(0), expected: StatusUnknown}, + {value: uint8(1), expected: StatusDraft}, + {value: uint8(2), expected: StatusReview}, + {value: uint8(3), expected: StatusPublished}, + {value: uint8(4), expected: StatusArchived}, + {value: uint16(0), expected: StatusUnknown}, + {value: uint16(1), expected: StatusDraft}, + {value: uint16(2), expected: StatusReview}, + {value: uint16(3), expected: StatusPublished}, + {value: uint16(4), expected: StatusArchived}, + {value: uint32(0), expected: StatusUnknown}, + {value: uint32(1), expected: StatusDraft}, + {value: uint32(2), expected: StatusReview}, + {value: uint32(3), expected: StatusPublished}, + {value: uint32(4), expected: StatusArchived}, + {value: uint64(0), expected: StatusUnknown}, + {value: uint64(1), expected: StatusDraft}, + {value: uint64(2), expected: StatusReview}, + {value: uint64(3), expected: StatusPublished}, + {value: uint64(4), expected: StatusArchived}, + {value: int(0), expected: StatusUnknown}, + {value: int(1), expected: StatusDraft}, + {value: int(2), expected: StatusReview}, + {value: int(3), expected: StatusPublished}, + {value: int(4), expected: StatusArchived}, + {value: int8(0), expected: StatusUnknown}, + {value: int8(1), expected: StatusDraft}, + {value: int8(2), expected: StatusReview}, + {value: int8(3), expected: StatusPublished}, + {value: int8(4), expected: StatusArchived}, + {value: int16(0), expected: StatusUnknown}, + {value: int16(1), expected: StatusDraft}, + {value: int16(2), expected: StatusReview}, + {value: int16(3), expected: StatusPublished}, + {value: int16(4), expected: StatusArchived}, + {value: int32(0), expected: StatusUnknown}, + {value: int32(1), expected: StatusDraft}, + {value: int32(2), expected: StatusReview}, + {value: int32(3), expected: StatusPublished}, + {value: int32(4), expected: StatusArchived}, + {value: int64(0), expected: StatusUnknown}, + {value: int64(1), expected: StatusDraft}, + {value: int64(2), expected: StatusReview}, + {value: int64(3), expected: StatusPublished}, + {value: int64(4), expected: StatusArchived}, + {value: float32(0), expected: StatusUnknown}, + {value: float32(1), expected: StatusDraft}, + {value: float32(2), expected: StatusReview}, + {value: float32(3), expected: StatusPublished}, + {value: float32(4), expected: StatusArchived}, + {value: float64(0), expected: StatusUnknown}, + {value: float64(1), expected: StatusDraft}, + {value: float64(2), expected: StatusReview}, + {value: float64(3), expected: StatusPublished}, + {value: float64(4), expected: StatusArchived}, + {value: StatusUnknown, expected: StatusUnknown}, + {value: StatusDraft, expected: StatusDraft}, + {value: StatusReview, expected: StatusReview}, + {value: StatusPublished, expected: StatusPublished}, + {value: StatusArchived, expected: StatusArchived}, + {value: Status(0), expected: StatusUnknown}, + {value: Status(1), expected: StatusDraft}, + {value: Status(2), expected: StatusReview}, + {value: Status(3), expected: StatusPublished}, + {value: Status(4), expected: StatusArchived}, + } + + invalidTestCases := []struct { + value any + expected string + }{ + {value: "foo", expected: `invalid enumify_test.Status value: "foo"`}, + {value: "Unbekannt", expected: `invalid enumify_test.Status value: "unbekannt"`}, + {value: "Entwurf", expected: `invalid enumify_test.Status value: "entwurf"`}, + {value: "Überprüfung", expected: `invalid enumify_test.Status value: "überprüfung"`}, + {value: "Veröffentlicht", expected: `invalid enumify_test.Status value: "veröffentlicht"`}, + {value: "Archiviert", expected: `invalid enumify_test.Status value: "archiviert"`}, + {value: "Needs Review", expected: `invalid enumify_test.Status value: "needs review"`}, + {value: nil, expected: `cannot parse into enumify_test.Status`}, + {value: true, expected: `cannot parse bool into enumify_test.Status`}, + {value: Status(42), expected: `invalid enumify_test.Status value: 42`}, + {value: uint(42), expected: `invalid enumify_test.Status value: 42`}, + {value: uint8(42), expected: `invalid enumify_test.Status value: 42`}, + {value: uint16(42), expected: `invalid enumify_test.Status value: 42`}, + {value: uint32(42), expected: `invalid enumify_test.Status value: 42`}, + {value: uint64(42), expected: `invalid enumify_test.Status value: 42`}, + {value: int(-1), expected: `invalid enumify_test.Status value: -1`}, + {value: int(42), expected: `invalid enumify_test.Status value: 42`}, + {value: int8(-1), expected: `invalid enumify_test.Status value: -1`}, + {value: int8(42), expected: `invalid enumify_test.Status value: 42`}, + {value: int16(-1), expected: `invalid enumify_test.Status value: -1`}, + {value: int16(42), expected: `invalid enumify_test.Status value: 42`}, + {value: int32(-1), expected: `invalid enumify_test.Status value: -1`}, + {value: int32(42), expected: `invalid enumify_test.Status value: 42`}, + {value: int64(-1), expected: `invalid enumify_test.Status value: -1`}, + {value: int64(42), expected: `invalid enumify_test.Status value: 42`}, + {value: float32(-1), expected: `invalid enumify_test.Status value: -1.000000`}, + {value: float32(42), expected: `invalid enumify_test.Status value: 42.000000`}, + {value: float64(-1), expected: `invalid enumify_test.Status value: -1.000000`}, + {value: float64(42), expected: `invalid enumify_test.Status value: 42.000000`}, + } + + t.Run("Names", func(t *testing.T) { + parse := enumify.ParseFactory[Status](StatusNames) + for i, tc := range validTestCases { + t.Run(fmt.Sprintf("Valid/%02d", i+1), func(t *testing.T) { + actual, err := parse(tc.value) + require.NoError(t, err) + require.Equal(t, tc.expected, actual) + }) + } + + for i, tc := range invalidTestCases { + t.Run(fmt.Sprintf("Invalid/%02d", i+1), func(t *testing.T) { + actual, err := parse(tc.value) + require.EqualError(t, err, tc.expected) + require.Equal(t, StatusUnknown, actual) + }) + } + }) + + t.Run("Names2D", func(t *testing.T) { + parse := enumify.ParseFactory[Status](StatusNames2D) + for i, tc := range validTestCases { + t.Run(fmt.Sprintf("Valid/%02d", i+1), func(t *testing.T) { + actual, err := parse(tc.value) + require.NoError(t, err) + require.Equal(t, tc.expected, actual) + }) + } + + for i, tc := range invalidTestCases { + t.Run(fmt.Sprintf("Invalid/%02d", i+1), func(t *testing.T) { + actual, err := parse(tc.value) + require.EqualError(t, err, tc.expected) + require.Equal(t, StatusUnknown, actual) + }) + } + }) + + t.Run("InvalidNames", func(t *testing.T) { + require.Panics(t, func() { + enumify.ParseFactory[Status]([]string{}) + }) + require.Panics(t, func() { + enumify.ParseFactory[Status]([][]string{}) + }) + require.Panics(t, func() { + enumify.ParseFactory[Status]([][]string{{}, {}}) + }) + }) +}