diff --git a/Readme.md b/Readme.md index d5ae277..3bf6d3f 100644 --- a/Readme.md +++ b/Readme.md @@ -116,7 +116,6 @@ $ jaydiff --report --show-types --ignore-excess old.json new.json # Ideas -- Handle circular references properly - JayPatch - Have the diff lib support more types (Structs, interfaces (?), Arrays, ...) - JSON-ish output (instead of go-ish) diff --git a/diff/diff.go b/diff/diff.go index 18fab5f..7c9eb2c 100644 --- a/diff/diff.go +++ b/diff/diff.go @@ -34,14 +34,18 @@ type Differ interface { // // BUG(yazgazan): An infinite recursion is possible if the lhs and/or rhs objects are cyclic func Diff(lhs, rhs interface{}) (Differ, error) { + return diff(lhs, rhs, &visited{}) +} + +func diff(lhs, rhs interface{}, visited *visited) (Differ, error) { lhsVal := reflect.ValueOf(lhs) rhsVal := reflect.ValueOf(rhs) - if lhs == nil && rhs == nil { - return scalar{lhs, rhs}, nil + if d, ok := nilCheck(lhs, rhs); ok { + return d, nil } - if lhs == nil || rhs == nil { - return types{lhs, rhs}, nil + if err := visited.add(lhsVal, rhsVal); err != nil { + return types{lhs, rhs}, ErrCyclic } if lhsVal.Type().Comparable() && rhsVal.Type().Comparable() { return scalar{lhs, rhs}, nil @@ -49,16 +53,28 @@ func Diff(lhs, rhs interface{}) (Differ, error) { if lhsVal.Kind() != rhsVal.Kind() { return types{lhs, rhs}, nil } + if lhsVal.Kind() == reflect.Slice { - return newSlice(lhs, rhs) + return newSlice(lhs, rhs, visited) } if lhsVal.Kind() == reflect.Map { - return newMap(lhs, rhs) + return newMap(lhs, rhs, visited) } return types{lhs, rhs}, &ErrUnsupported{lhsVal.Type(), rhsVal.Type()} } +func nilCheck(lhs, rhs interface{}) (Differ, bool) { + if lhs == nil && rhs == nil { + return scalar{lhs, rhs}, true + } + if lhs == nil || rhs == nil { + return types{lhs, rhs}, true + } + + return nil, false +} + func (t Type) String() string { switch t { case Identical: @@ -95,3 +111,46 @@ func IsMissing(d Differ) bool { return true } } + +type visited struct { + LHS []uintptr + RHS []uintptr +} + +func (v *visited) add(lhs, rhs reflect.Value) error { + if canAddr(lhs) { + if inPointers(v.LHS, lhs) { + return ErrCyclic + } + v.LHS = append(v.LHS, lhs.Pointer()) + } + if canAddr(rhs) { + if inPointers(v.RHS, rhs) { + return ErrCyclic + } + v.RHS = append(v.RHS, rhs.Pointer()) + } + + return nil +} + +func inPointers(pointers []uintptr, val reflect.Value) bool { + for _, lhs := range pointers { + if lhs == val.Pointer() { + return true + } + } + + return false +} + +func canAddr(val reflect.Value) bool { + switch val.Kind() { + case reflect.Chan, reflect.Func, reflect.Map: + fallthrough + case reflect.Ptr, reflect.Slice, reflect.UnsafePointer: + return true + } + + return false +} diff --git a/diff/diff_test.go b/diff/diff_test.go index 54156bf..8ff7a2d 100644 --- a/diff/diff_test.go +++ b/diff/diff_test.go @@ -223,7 +223,7 @@ func TestSlice(t *testing.T) { Type: ContentDiffer, }, } { - typ, err := newSlice(test.LHS, test.RHS) + typ, err := newSlice(test.LHS, test.RHS, &visited{}) if err != nil { t.Errorf("NewSlice(%+v, %+v): unexpected error: %q", test.LHS, test.RHS, err) @@ -238,7 +238,7 @@ func TestSlice(t *testing.T) { testStrings("TestSlice", t, test, ss, indented) } - invalid, err := newSlice(nil, nil) + invalid, err := newSlice(nil, nil, &visited{}) if invalidErr, ok := err.(errInvalidType); ok { if !strings.Contains(invalidErr.Error(), "nil") { t.Errorf("NewSlice(nil, nil): unexpected format for InvalidType error: got %s", err) @@ -256,7 +256,7 @@ func TestSlice(t *testing.T) { t.Errorf("invalidSlice.StringIndent(%q, %q, %+v) = %q, expected %q", testKey, testPrefix, testOutput, indented, "") } - invalid, err = newSlice([]int{}, nil) + invalid, err = newSlice([]int{}, nil, &visited{}) if invalidErr, ok := err.(errInvalidType); ok { if !strings.Contains(invalidErr.Error(), "nil") { t.Errorf("NewSlice([]int{}, nil): unexpected format for InvalidType error: got %s", err) @@ -338,7 +338,7 @@ func TestMap(t *testing.T) { Type: ContentDiffer, }, } { - m, err := newMap(test.LHS, test.RHS) + m, err := newMap(test.LHS, test.RHS, &visited{}) if err != nil { t.Errorf("NewMap(%+v, %+v): unexpected error: %q", test.LHS, test.RHS, err) @@ -353,7 +353,7 @@ func TestMap(t *testing.T) { testStrings(fmt.Sprintf("TestMap[%d]", i), t, test, ss, indented) } - invalid, err := newMap(nil, nil) + invalid, err := newMap(nil, nil, &visited{}) if invalidErr, ok := err.(errInvalidType); ok { if !strings.Contains(invalidErr.Error(), "nil") { t.Errorf("NewMap(nil, nil): unexpected format for InvalidType error: got %s", err) @@ -371,7 +371,7 @@ func TestMap(t *testing.T) { t.Errorf("invalidMap.StringIndent(%q, %q, %+v) = %q, expected %q", testKey, testPrefix, testOutput, indented, "") } - invalid, err = newMap(map[int]int{}, nil) + invalid, err = newMap(map[int]int{}, nil, &visited{}) if invalidErr, ok := err.(errInvalidType); ok { if !strings.Contains(invalidErr.Error(), "nil") { t.Errorf("NewMap(map[int]int{}, nil): unexpected format for InvalidType error: got %s", err) @@ -390,6 +390,47 @@ func TestMap(t *testing.T) { } } +func TestCircular(t *testing.T) { + first := map[int]interface{}{} + second := map[int]interface{}{ + 0: first, + } + first[0] = second + notCyclic := map[int]interface{}{ + 0: map[int]interface{}{ + 0: map[int]interface{}{ + 0: "foo", + }, + }, + } + + for _, test := range []struct { + lhs interface{} + rhs interface{} + wantError bool + }{ + {lhs: first, rhs: first, wantError: true}, + {lhs: first, rhs: second, wantError: true}, + {lhs: first, rhs: second, wantError: true}, + {lhs: first, rhs: notCyclic, wantError: true}, + {lhs: notCyclic, rhs: first, wantError: true}, + {lhs: notCyclic, rhs: notCyclic}, + } { + d, err := Diff(test.lhs, test.rhs) + + if test.wantError && (err == nil || err != ErrCyclic) { + t.Errorf("Expected error %q, got %q", ErrCyclic, err) + } + if !test.wantError && err != nil { + t.Errorf("Unexpected error %q", err) + } + + if test.wantError && d.Diff() != ContentDiffer { + t.Errorf("Expected Diff() to be %s, got %s", ContentDiffer, d.Diff()) + } + } +} + func TestIgnore(t *testing.T) { ignoreDiff, _ := Ignore() diff --git a/diff/errors.go b/diff/errors.go index 488fa3b..109f81f 100644 --- a/diff/errors.go +++ b/diff/errors.go @@ -1,6 +1,7 @@ package diff import ( + "errors" "fmt" "reflect" ) @@ -23,3 +24,6 @@ type errInvalidType struct { func (e errInvalidType) Error() string { return fmt.Sprintf("%T is not a valid type for %s", e.Value, e.For) } + +// ErrCyclic is returned when one of the compared values contain circular references +var ErrCyclic = errors.New("circular references not supported") diff --git a/diff/map.go b/diff/map.go index a0ff66d..76e2233 100644 --- a/diff/map.go +++ b/diff/map.go @@ -21,7 +21,7 @@ type mapExcess struct { value interface{} } -func newMap(lhs, rhs interface{}) (mapDiff, error) { +func newMap(lhs, rhs interface{}, visited *visited) (Differ, error) { var diffs = make(map[interface{}]Differ) lhsVal := reflect.ValueOf(lhs) @@ -41,7 +41,7 @@ func newMap(lhs, rhs interface{}) (mapDiff, error) { rhsEl := rhsVal.MapIndex(key) if lhsEl.IsValid() && rhsEl.IsValid() { - diff, err := Diff(lhsEl.Interface(), rhsEl.Interface()) + diff, err := diff(lhsEl.Interface(), rhsEl.Interface(), visited) if diff.Diff() != Identical { } diffs[key.Interface()] = diff diff --git a/diff/slice.go b/diff/slice.go index 4661534..6d29a14 100644 --- a/diff/slice.go +++ b/diff/slice.go @@ -20,7 +20,7 @@ type sliceExcess struct { value interface{} } -func newSlice(lhs, rhs interface{}) (Differ, error) { +func newSlice(lhs, rhs interface{}, visited *visited) (Differ, error) { var diffs []Differ lhsVal := reflect.ValueOf(lhs) @@ -39,7 +39,7 @@ func newSlice(lhs, rhs interface{}) (Differ, error) { for i := 0; i < nElems; i++ { if i < lhsVal.Len() && i < rhsVal.Len() { - diff, err := Diff(lhsVal.Index(i).Interface(), rhsVal.Index(i).Interface()) + diff, err := diff(lhsVal.Index(i).Interface(), rhsVal.Index(i).Interface(), visited) if diff.Diff() != Identical { } diffs = append(diffs, diff)