diff --git a/README.md b/README.md index a77eaf8..233f7d4 100644 --- a/README.md +++ b/README.md @@ -240,23 +240,6 @@ Custom: foo.bar%2Fqux=XYZ (`%5C` and `%2F` represent `\` and `/`, respectively.) -Limitations ------------ - - - Circular (self-referential) values are untested. - -Future Work ------------ - -The following items would be nice to have in the future—though they are not being worked on yet: - - - An option to automatically treat all field names in `camelCase` or `underscore_case`. - - Built-in support for the types in [`math/big`](http://golang.org/pkg/math/big/). - - Built-in support for the types in [`image/color`](http://golang.org/pkg/image/color/). - - Improve encoding/decoding by reading/writing directly from/to the `io.Reader`/`io.Writer` when possible, rather than going through an intermediate representation (i.e. `node`) which requires more memory. - -(Feel free to implement any of these and then send a pull request.) - Related Work ------------ diff --git a/TODO.md b/TODO.md index d344727..76c7107 100644 --- a/TODO.md +++ b/TODO.md @@ -2,3 +2,7 @@ TODO ==== - Document IgnoreCase and IgnoreUnknownKeys in README. + - An option to automatically treat all field names in `camelCase` or `underscore_case`. + - Built-in support for the types in [`math/big`](http://golang.org/pkg/math/big/). + - Built-in support for the types in [`image/color`](http://golang.org/pkg/image/color/). + - Improve encoding/decoding by reading/writing directly from/to the `io.Reader`/`io.Writer` when possible, rather than going through an intermediate representation (i.e. `node`) which requires more memory. diff --git a/encode.go b/encode.go index de362a5..4c93f19 100644 --- a/encode.go +++ b/encode.go @@ -118,10 +118,11 @@ func encodeToNode(v reflect.Value, z bool, o bool) (n node, err error) { err = fmt.Errorf("%v", e) } }() - return getNode(encodeValue(v, z, o)), nil + seen := make(map[uintptr]bool) + return getNode(encodeValue(v, z, o, seen)), nil } -func encodeValue(v reflect.Value, z bool, o bool) interface{} { +func encodeValue(v reflect.Value, z bool, o bool, seen map[uintptr]bool) interface{} { t := v.Type() k := v.Kind() @@ -132,21 +133,35 @@ func encodeValue(v reflect.Value, z bool, o bool) interface{} { } switch k { - case reflect.Ptr, reflect.Interface: - return encodeValue(v.Elem(), z, o) + case reflect.Ptr: + ptr := v.Pointer() + if seen[ptr] { + panic("form: encoding a cycle via " + t.String()) + } + seen[ptr] = true + defer delete(seen, ptr) + return encodeValue(v.Elem(), z, o, seen) + case reflect.Interface: + return encodeValue(v.Elem(), z, o, seen) case reflect.Struct: if t.ConvertibleTo(timeType) { return encodeTime(v) } else if t.ConvertibleTo(urlType) { return encodeURL(v) } - return encodeStruct(v, z, o) + return encodeStruct(v, z, o, seen) case reflect.Slice: - return encodeSlice(v, z, o) + return encodeSlice(v, z, o, seen) case reflect.Array: - return encodeArray(v, z, o) + return encodeArray(v, z, o, seen) case reflect.Map: - return encodeMap(v, z, o) + ptr := v.Pointer() + if seen[ptr] { + panic("form: encoding a cycle via " + t.String()) + } + seen[ptr] = true + defer delete(seen, ptr) + return encodeMap(v, z, o, seen) case reflect.Invalid, reflect.Uintptr, reflect.UnsafePointer, reflect.Chan, reflect.Func: panic(t.String() + " has unsupported kind " + t.Kind().String()) default: @@ -160,7 +175,7 @@ type encoderField struct { omitempty bool } -func encodeStruct(v reflect.Value, z bool, o bool) interface{} { +func encodeStruct(v reflect.Value, z bool, o bool, seen map[uintptr]bool) interface{} { fields := collectFields(v.Type()) n := node{} for _, f := range fields { @@ -171,7 +186,7 @@ func encodeStruct(v reflect.Value, z bool, o bool) interface{} { if (o || f.omitempty) && isEmptyValue(fv) { continue } - n[f.name] = encodeValue(fv, z, o) + n[f.name] = encodeValue(fv, z, o, seen) } return n } @@ -326,31 +341,31 @@ func isLeafStruct(ft reflect.Type) bool { return ft.Implements(textMarshalerType) || reflect.PtrTo(ft).Implements(textMarshalerType) } -func encodeMap(v reflect.Value, z bool, o bool) interface{} { +func encodeMap(v reflect.Value, z bool, o bool, seen map[uintptr]bool) interface{} { n := node{} for _, i := range v.MapKeys() { - k := getString(encodeValue(i, z, o)) - n[k] = encodeValue(v.MapIndex(i), z, o) + k := getString(encodeValue(i, z, o, seen)) + n[k] = encodeValue(v.MapIndex(i), z, o, seen) } return n } -func encodeArray(v reflect.Value, z bool, o bool) interface{} { +func encodeArray(v reflect.Value, z bool, o bool, seen map[uintptr]bool) interface{} { n := node{} for i := 0; i < v.Len(); i++ { - n[strconv.Itoa(i)] = encodeValue(v.Index(i), z, o) + n[strconv.Itoa(i)] = encodeValue(v.Index(i), z, o, seen) } return n } -func encodeSlice(v reflect.Value, z bool, o bool) interface{} { +func encodeSlice(v reflect.Value, z bool, o bool, seen map[uintptr]bool) interface{} { t := v.Type() if t.Elem().Kind() == reflect.Uint8 { return string(v.Bytes()) // Encode byte slices as a single string by default. } n := node{} for i := 0; i < v.Len(); i++ { - n[strconv.Itoa(i)] = encodeValue(v.Index(i), z, o) + n[strconv.Itoa(i)] = encodeValue(v.Index(i), z, o, seen) } return n } diff --git a/encode_test.go b/encode_test.go index 3246f33..6b79fb5 100644 --- a/encode_test.go +++ b/encode_test.go @@ -137,6 +137,46 @@ func TestEncode_OmitEmpty(t *testing.T) { } } +func TestEncode_Cycle(t *testing.T) { + t.Run("self-referential struct pointer", func(t *testing.T) { + type Cyclic struct { + Name string + Next *Cyclic + } + a := &Cyclic{Name: "a"} + a.Next = a + + if _, err := EncodeToString(a); err == nil { + t.Error("expected error for cyclic struct pointer, got nil") + } + }) + + t.Run("map containing itself", func(t *testing.T) { + m := map[string]interface{}{} + m["self"] = m + + if _, err := EncodeToString(m); err == nil { + t.Error("expected error for cyclic map, got nil") + } + }) + + t.Run("non-cyclic pointer sharing (DAG)", func(t *testing.T) { + type Node struct { + Value string + } + type DAG struct { + A *Node + B *Node + } + shared := &Node{Value: "shared"} + dag := DAG{A: shared, B: shared} + + if _, err := EncodeToString(dag); err != nil { + t.Errorf("unexpected error for DAG: %s", err) + } + }) +} + func TestEncode_ConflictResolution(t *testing.T) { for _, c := range []struct { name string