Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,33 @@

We found ourselves generating a lot of boilerplate code for our Enums: particularly for parsing, serialization/deserialization, database storage, and tests. Really we just want to be able to describe an Enum and get all of this code for free! Enumify is the combination of a code generator for the boilerplate code as well as a package for reducing the boilerplate and ensuring that everything works as expected.

## Getting Started

Install enumify:

```sh
go install go.rtnl.ai/enumify/cmd/enumify@latest
```

This will add the `enumify` command to your `$PATH`.

Next steps:

1. Define enum schema file
2. Create generate command
3. Run go generate ./...
4. Run go mod tidy

Boom - your package is ready to go with enums!

## Theory

For us, an `enum` is a set of specific values that have string representations. In code we want our `enum` values to be `uint8` (if you have more than 255 values you probably want a database table instead of an `enum`) for compact memory usage and ease of comparisons. However for marshaling/unmarshaling the value to json/yaml or saving the value in a database, we'd like to use the string representation instead.

However, the boilerplate code for implementing `fmt.Stringer`, `json.Marshaler`, `json.Unmarshaler`, etc. is pretty verbose -- not to mention the testing. This library unifies all of that functionality into generic functions and also provides code generation to make it simple to define enumerations. Modifications require a change to the enum schema file and running `go generate ./...` and presto chango, a minimal set of code is added to your project, but code that is fully tested and organized.

Key theory point: the value behind the enum shouldn't matter, just its ordering. The first value (0) should always be the "unknown" or "default" value. You can compare enums but you cannot have them be specific runes or other values.

## License

This project is licensed under the BSD 3-Clause License. See [`LICENSE`](./LICENSE) for details. Please feel free to use Enumify in your own projects and applications.
Expand Down
15 changes: 15 additions & 0 deletions enumify.go
Original file line number Diff line number Diff line change
@@ -1 +1,16 @@
package enumify

import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
)

type Enum interface {
fmt.Stringer
json.Marshaler
json.Unmarshaler
sql.Scanner
driver.Valuer
}
29 changes: 29 additions & 0 deletions enumify_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package enumify_test

//============================================================================
// Test Enum Type
//============================================================================

type Status uint8

const (
StatusUnknown Status = iota
StatusDraft
StatusReview
StatusPublished
StatusArchived
)

var StatusNames = []string{
"unknown",
"draft",
"review",
"published",
"archived",
}

var StatusNames2D = [][]string{
{"unknown", "draft", "review", "published", "archived"},
{"Unknown", "Draft", "Needs Review", "Published", "Archived"},
{"Unbekannt", "Entwurf", "Überprüfung", "Veröffentlicht", "Archiviert"},
}
8 changes: 8 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
module go.rtnl.ai/enumify

go 1.26.1

require github.com/stretchr/testify v1.11.1

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
141 changes: 141 additions & 0 deletions parse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package enumify

import (
"fmt"
"strings"
)

// Generic parser function that can be used to parse a specific enum type.
type Parser[T ~uint8] func(any) (T, error)

// ParseFactory creates a parser function for a specific enum type. The parser can parse
// the enum from a string or a numeric type that can be converted into a uint8. In order
// to support string parsing, a names array is required. It can either be a single
// array where the index of the name is the value of the enum, or a 2D array where the
// first column contains the names indexed according to the enum value.
//
// NOTE: string parsing is case-insensitive and leading and trailing whitespace is
// ignored. Names should be slug values to ensure consistent parsing.
func ParseFactory[T ~uint8, Names []string | [][]string](names Names) Parser[T] {
var normalizedNames []string
switch col := any(names).(type) {
case []string:
if len(col) < 1 {
panic(fmt.Errorf("names array must contain at least one name"))
}

normalizedNames = make([]string, len(col))
for i, name := range col {
normalizedNames[i] = normalize(name)
}
case [][]string:
if len(col) < 1 {
panic(fmt.Errorf("names array must contain at least one column"))
}

if len(col[0]) < 1 {
panic(fmt.Errorf("names array must contain at least one name"))
}

normalizedNames = make([]string, len(col[0]))
for i, name := range col[0] {
normalizedNames[i] = normalize(name)
}
}

// The "unknown" value is the zero-valued T.
unknown := T(0)

return func(val any) (T, error) {
switch v := val.(type) {
case string:
v = normalize(v)

// For an empty string, return the "unknown" or zero-valued T.
if v == "" {
return unknown, nil
}

// Iterate over the normalized names and return the value of the first match.
for i, name := range normalizedNames {
if name == v {
return T(i), nil
}
}

// If no match is found, return an error.
return unknown, fmt.Errorf("invalid %T value: %q", unknown, v)
case T:
if v >= T(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return v, nil
case uint:
if v >= uint(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case uint8:
if v >= uint8(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case uint16:
if v >= uint16(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case uint32:
if v >= uint32(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case uint64:
if v >= uint64(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case int:
if v < 0 || v >= len(normalizedNames) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case int8:
if v < 0 || v >= int8(len(normalizedNames)) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Integer overflow in int8 bounds check

Medium Severity

The cast int8(len(normalizedNames)) silently overflows when the names slice has 128 or more entries. Since T is ~uint8 and can represent up to 256 enum values, this is a reachable scenario. When overflowed (e.g., 130 names → int8(130) = −126), the comparison v >= -126 becomes true for all non-negative int8 values, incorrectly rejecting every valid input.

Fix in Cursor Fix in Web

return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case int16:
if v < 0 || v >= int16(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case int32:
if v < 0 || v >= int32(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case int64:
if v < 0 || v >= int64(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %d", unknown, v)
}
return T(v), nil
case float32:
if v < 0 || v >= float32(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %f", unknown, v)
}
return T(v), nil
case float64:
if v < 0 || v >= float64(len(normalizedNames)) {
return unknown, fmt.Errorf("invalid %T value: %f", unknown, v)
}
return T(v), nil
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float values silently truncated to enum integers

Low Severity

Non-integer float values (e.g., float64(2.7)) pass the bounds check and are silently truncated by T(v), mapping them to an unrelated enum value. There's no check that the float is actually a whole number, so a value like 2.9 is quietly accepted as enum value 2.

Fix in Cursor Fix in Web

default:
return unknown, fmt.Errorf("cannot parse %T into %T", v, unknown)
}
}
}

func normalize(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}
Loading
Loading