diff --git a/examples/Taskfile.yml b/examples/Taskfile.yml index 497b25c..fc1edb6 100644 --- a/examples/Taskfile.yml +++ b/examples/Taskfile.yml @@ -11,6 +11,7 @@ tasks: cmd: >- APP_DATABASE_USERNAME=user APP_DATABASE_PASSWORD=pass + APP_TIME=2021-01-01T00:00:00Z go run main.go --debug --config=files/override.yaml diff --git a/examples/simple/expected.json b/examples/simple/expected.json index 3ad4dd4..26ed216 100644 --- a/examples/simple/expected.json +++ b/examples/simple/expected.json @@ -9,5 +9,6 @@ "username": "user", "password": "pass" }, - "Config": "files/override.yaml" + "Config": "files/override.yaml", + "time": "2021-01-01T00:00:00Z" } diff --git a/examples/simple/main.go b/examples/simple/main.go index cf9c0f1..44dff04 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -5,15 +5,17 @@ import ( "flag" "log" "os" + "time" "github.com/codetent/confless" ) type Config struct { - Name string `json:"name"` - Host string `json:"host"` - Port int `json:"port"` - Debug bool `json:"debug"` + Name string `json:"name"` + Host string `json:"host"` + Port int `json:"port"` + Debug bool `json:"debug"` + Time time.Time `json:"time"` Database struct { Host string `json:"host"` diff --git a/pkg/dotpath/marshal.go b/pkg/dotpath/marshal.go new file mode 100644 index 0000000..176a953 --- /dev/null +++ b/pkg/dotpath/marshal.go @@ -0,0 +1,35 @@ +package dotpath + +import ( + "encoding" + "fmt" + "reflect" +) + +// Unmarshal the value using a custom unmarshaler. +func unmarshalText(v reflect.Value, value any) error { + // Convert the value to bytes. + valueBytes, ok := value.([]byte) + if !ok { + valueString, ok := value.(string) + if !ok { + return fmt.Errorf("%w: %s", ErrUnsupportedType, v.Kind()) + } + + valueBytes = []byte(valueString) + } + + // Check if the value implements TextUnmarshaler. + unmarshaler, ok := v.Addr().Interface().(encoding.TextUnmarshaler) + if !ok { + return fmt.Errorf("%w: %s", ErrUnsupportedType, v.Kind()) + } + + // Unmarshal the value. + err := unmarshaler.UnmarshalText(valueBytes) + if err != nil { + return fmt.Errorf("%w: failed to unmarshal value: %w", ErrInvalidValue, err) + } + + return nil +} diff --git a/pkg/dotpath/reflect.go b/pkg/dotpath/reflect.go index a990793..aa1d517 100644 --- a/pkg/dotpath/reflect.go +++ b/pkg/dotpath/reflect.go @@ -1,7 +1,7 @@ package dotpath import ( - "encoding/json" + "errors" "fmt" "reflect" "strconv" @@ -10,6 +10,11 @@ import ( "github.com/spf13/cast" ) +var ( + ErrInvalidValue = errors.New("invalid value") + ErrUnsupportedType = errors.New("unsupported type") +) + // Extract names from tags. func namesFromTags(f reflect.StructField) []string { names := make([]string, 0, 2) @@ -80,7 +85,7 @@ func getValue(v reflect.Value, p string) (reflect.Value, error) { v = v.Index(index) default: - return reflect.Value{}, fmt.Errorf("unsupported type: %s", v.Kind()) + return reflect.Value{}, fmt.Errorf("%w: %s", ErrUnsupportedType, v.Kind()) } // Pop the first part of the path. @@ -98,22 +103,7 @@ func setValue(v reflect.Value, value any) error { // If the value is not settable, return an error. if !v.CanSet() { - return fmt.Errorf("value is not settable") - } - - // If the value is a json.Unmarshaler, use it to unmarshal the value. - unmarshaler, ok := v.Addr().Interface().(json.Unmarshaler) - if ok { - b, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("failed to marshal value: %w", err) - } - - err = unmarshaler.UnmarshalJSON(b) - if err != nil { - return fmt.Errorf("failed to unmarshal value: %w", err) - } - return nil + return fmt.Errorf("%w: value is not settable", ErrInvalidValue) } // Handle basic types. @@ -121,40 +111,41 @@ func setValue(v reflect.Value, value any) error { case reflect.String: c, err := cast.ToStringE(value) if err != nil { - return fmt.Errorf("failed to cast value: %w", err) + return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err) } v.SetString(c) case reflect.Bool: c, err := cast.ToBoolE(value) if err != nil { - return fmt.Errorf("failed to cast value: %w", err) + return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err) } v.SetBool(c) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: c, err := cast.ToInt64E(value) if err != nil { - return fmt.Errorf("failed to cast value: %w", err) + return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err) } v.SetInt(c) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: c, err := cast.ToUint64E(value) if err != nil { - return fmt.Errorf("failed to cast value: %w", err) + return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err) } v.SetUint(c) case reflect.Float32, reflect.Float64: c, err := cast.ToFloat64E(value) if err != nil { - return fmt.Errorf("failed to cast value: %w", err) + return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err) } v.SetFloat(c) default: - return fmt.Errorf("unsupported type: %s", v.Kind()) + // Try to unmarshal the value using the custom unmarshaler. + return unmarshalText(v, value) } return nil diff --git a/pkg/dotpath/reflect_test.go b/pkg/dotpath/reflect_test.go index 56147fc..df043dc 100644 --- a/pkg/dotpath/reflect_test.go +++ b/pkg/dotpath/reflect_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "reflect" "testing" + "time" ) // CustomUnmarshaler is a custom type that implements json.Unmarshaler for testing @@ -433,13 +434,12 @@ func Test_setValue(t *testing.T) { wantErr: true, }, { - name: "json.Unmarshaler", - v: reflect.ValueOf(new(CustomUnmarshaler)).Elem(), - value: "test", + name: "text unmarshaler", + v: reflect.ValueOf(new(time.Time)).Elem(), + value: "2021-01-01T00:00:00Z", validate: func(t *testing.T, v reflect.Value) { - c := v.Interface().(CustomUnmarshaler) - if c.Value != "unmarshaled:test" { - t.Errorf("got %v, want unmarshaled:test", c.Value) + if v.Interface().(time.Time).Unix() != 1609459200 { + t.Errorf("got %v, want 1609459200", v.Interface().(time.Time).Unix()) } }, }, diff --git a/pkg/merge/merge.go b/pkg/merge/merge.go new file mode 100644 index 0000000..36147f7 --- /dev/null +++ b/pkg/merge/merge.go @@ -0,0 +1,53 @@ +package merge + +import ( + "fmt" + "reflect" + + "dario.cat/mergo" +) + +// Transformer that detects zero values using the IsZero method (if defined). +// An example of a type that implements the IsZero method is time.Time. +type isZeroTransformer struct{} + +func (t *isZeroTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { + // Check if the type implements the IsZero method. + isZero, hasIsZero := typ.MethodByName("IsZero") + if !hasIsZero { + return nil + } + + return func(dst, src reflect.Value) error { + // Check if the destination value is settable. + if !dst.CanSet() { + return nil + } + + // Check if is zero and set the source value if it is. + result := isZero.Func.Call([]reflect.Value{src}) + if result[0].Bool() { + return nil + } + + // Set the source value. + dst.Set(src) + + return nil + } +} + +func Merge(dst any, src any) error { + err := mergo.Merge( + dst, + src, + mergo.WithOverride, + mergo.WithTypeCheck, + mergo.WithTransformers(&isZeroTransformer{}), + ) + if err != nil { + return fmt.Errorf("failed to merge: %w", err) + } + + return nil +} diff --git a/populate.go b/populate.go index 5bc71f0..ba8a5a7 100644 --- a/populate.go +++ b/populate.go @@ -9,10 +9,10 @@ import ( "reflect" "strings" - "dario.cat/mergo" "github.com/goccy/go-yaml" "github.com/codetent/confless/pkg/dotpath" + "github.com/codetent/confless/pkg/merge" "github.com/codetent/confless/pkg/reflectutil" ) @@ -111,7 +111,7 @@ func populateByFile(r io.Reader, format string, obj any) error { } // Merge the decoded object into the given object. - err := mergo.Merge(obj, decoded, mergo.WithOverride) + err := merge.Merge(obj, decoded) if err != nil { return fmt.Errorf("failed to merge: %w", err) }