Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tasks:
cmd: >-
APP_DATABASE_USERNAME=user
APP_DATABASE_PASSWORD=pass
APP_TIME=2021-01-01T00:00:00Z
go run main.go
--debug
--config=files/override.yaml
Expand Down
3 changes: 2 additions & 1 deletion examples/simple/expected.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
"username": "user",
"password": "pass"
},
"Config": "files/override.yaml"
"Config": "files/override.yaml",
"time": "2021-01-01T00:00:00Z"
}
10 changes: 6 additions & 4 deletions examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ import (
"flag"
"log"
"os"
"time"

"github.com/codetent/confless"
)

type Config struct {
Name string `json:"name"`
Host string `json:"host"`
Port int `json:"port"`
Debug bool `json:"debug"`
Name string `json:"name"`
Host string `json:"host"`
Port int `json:"port"`
Debug bool `json:"debug"`
Time time.Time `json:"time"`

Database struct {
Host string `json:"host"`
Expand Down
35 changes: 35 additions & 0 deletions pkg/dotpath/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package dotpath

import (
"encoding"
"fmt"
"reflect"
)

// Unmarshal the value using a custom unmarshaler.
func unmarshalText(v reflect.Value, value any) error {
// Convert the value to bytes.
valueBytes, ok := value.([]byte)
if !ok {
valueString, ok := value.(string)
if !ok {
return fmt.Errorf("%w: %s", ErrUnsupportedType, v.Kind())
}

valueBytes = []byte(valueString)
}

// Check if the value implements TextUnmarshaler.
unmarshaler, ok := v.Addr().Interface().(encoding.TextUnmarshaler)
if !ok {
return fmt.Errorf("%w: %s", ErrUnsupportedType, v.Kind())
}

// Unmarshal the value.
err := unmarshaler.UnmarshalText(valueBytes)
if err != nil {
return fmt.Errorf("%w: failed to unmarshal value: %w", ErrInvalidValue, err)
}

return nil
}
39 changes: 15 additions & 24 deletions pkg/dotpath/reflect.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dotpath

import (
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
Expand All @@ -10,6 +10,11 @@ import (
"github.com/spf13/cast"
)

var (
ErrInvalidValue = errors.New("invalid value")
ErrUnsupportedType = errors.New("unsupported type")
)

// Extract names from tags.
func namesFromTags(f reflect.StructField) []string {
names := make([]string, 0, 2)
Expand Down Expand Up @@ -80,7 +85,7 @@ func getValue(v reflect.Value, p string) (reflect.Value, error) {

v = v.Index(index)
default:
return reflect.Value{}, fmt.Errorf("unsupported type: %s", v.Kind())
return reflect.Value{}, fmt.Errorf("%w: %s", ErrUnsupportedType, v.Kind())
}

// Pop the first part of the path.
Expand All @@ -98,63 +103,49 @@ func setValue(v reflect.Value, value any) error {

// If the value is not settable, return an error.
if !v.CanSet() {
return fmt.Errorf("value is not settable")
}

// If the value is a json.Unmarshaler, use it to unmarshal the value.
unmarshaler, ok := v.Addr().Interface().(json.Unmarshaler)
if ok {
b, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}

err = unmarshaler.UnmarshalJSON(b)
if err != nil {
return fmt.Errorf("failed to unmarshal value: %w", err)
}
return nil
return fmt.Errorf("%w: value is not settable", ErrInvalidValue)
}

// Handle basic types.
switch v.Kind() {
case reflect.String:
c, err := cast.ToStringE(value)
if err != nil {
return fmt.Errorf("failed to cast value: %w", err)
return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err)
}

v.SetString(c)
case reflect.Bool:
c, err := cast.ToBoolE(value)
if err != nil {
return fmt.Errorf("failed to cast value: %w", err)
return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err)
}

v.SetBool(c)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
c, err := cast.ToInt64E(value)
if err != nil {
return fmt.Errorf("failed to cast value: %w", err)
return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err)
}

v.SetInt(c)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
c, err := cast.ToUint64E(value)
if err != nil {
return fmt.Errorf("failed to cast value: %w", err)
return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err)
}

v.SetUint(c)
case reflect.Float32, reflect.Float64:
c, err := cast.ToFloat64E(value)
if err != nil {
return fmt.Errorf("failed to cast value: %w", err)
return fmt.Errorf("%w: failed to cast value: %w", ErrInvalidValue, err)
}

v.SetFloat(c)
default:
return fmt.Errorf("unsupported type: %s", v.Kind())
// Try to unmarshal the value using the custom unmarshaler.
return unmarshalText(v, value)
}

return nil
Expand Down
12 changes: 6 additions & 6 deletions pkg/dotpath/reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"reflect"
"testing"
"time"
)

// CustomUnmarshaler is a custom type that implements json.Unmarshaler for testing
Expand Down Expand Up @@ -433,13 +434,12 @@ func Test_setValue(t *testing.T) {
wantErr: true,
},
{
name: "json.Unmarshaler",
v: reflect.ValueOf(new(CustomUnmarshaler)).Elem(),
value: "test",
name: "text unmarshaler",
v: reflect.ValueOf(new(time.Time)).Elem(),
value: "2021-01-01T00:00:00Z",
validate: func(t *testing.T, v reflect.Value) {
c := v.Interface().(CustomUnmarshaler)
if c.Value != "unmarshaled:test" {
t.Errorf("got %v, want unmarshaled:test", c.Value)
if v.Interface().(time.Time).Unix() != 1609459200 {
t.Errorf("got %v, want 1609459200", v.Interface().(time.Time).Unix())
}
},
},
Expand Down
53 changes: 53 additions & 0 deletions pkg/merge/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package merge

import (
"fmt"
"reflect"

"dario.cat/mergo"
)

// Transformer that detects zero values using the IsZero method (if defined).
// An example of a type that implements the IsZero method is time.Time.
type isZeroTransformer struct{}

func (t *isZeroTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error {
// Check if the type implements the IsZero method.
isZero, hasIsZero := typ.MethodByName("IsZero")
if !hasIsZero {
return nil
}

return func(dst, src reflect.Value) error {
// Check if the destination value is settable.
if !dst.CanSet() {
return nil
}

// Check if is zero and set the source value if it is.
result := isZero.Func.Call([]reflect.Value{src})
if result[0].Bool() {
return nil
}

// Set the source value.
dst.Set(src)

return nil
}
}

func Merge(dst any, src any) error {
err := mergo.Merge(
dst,
src,
mergo.WithOverride,
mergo.WithTypeCheck,
mergo.WithTransformers(&isZeroTransformer{}),
)
if err != nil {
return fmt.Errorf("failed to merge: %w", err)
}

return nil
}
4 changes: 2 additions & 2 deletions populate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
"reflect"
"strings"

"dario.cat/mergo"
"github.com/goccy/go-yaml"

"github.com/codetent/confless/pkg/dotpath"
"github.com/codetent/confless/pkg/merge"
"github.com/codetent/confless/pkg/reflectutil"
)

Expand Down Expand Up @@ -111,7 +111,7 @@ func populateByFile(r io.Reader, format string, obj any) error {
}

// Merge the decoded object into the given object.
err := mergo.Merge(obj, decoded, mergo.WithOverride)
err := merge.Merge(obj, decoded)
if err != nil {
return fmt.Errorf("failed to merge: %w", err)
}
Expand Down
Loading