Skip to content

Commit

Permalink
Merge pull request #2 from yazgazan/support-circular-references
Browse files Browse the repository at this point in the history
Adding support for circular references
  • Loading branch information
yazgazan authored May 3, 2017
2 parents f5a7c2a + 9a867e5 commit 34459d8
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 17 deletions.
1 change: 0 additions & 1 deletion Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ $ jaydiff --report --show-types --ignore-excess old.json new.json

# Ideas

- Handle circular references properly
- JayPatch
- Have the diff lib support more types (Structs, interfaces (?), Arrays, ...)
- JSON-ish output (instead of go-ish)
71 changes: 65 additions & 6 deletions diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,47 @@ type Differ interface {
//
// BUG(yazgazan): An infinite recursion is possible if the lhs and/or rhs objects are cyclic
func Diff(lhs, rhs interface{}) (Differ, error) {
return diff(lhs, rhs, &visited{})
}

func diff(lhs, rhs interface{}, visited *visited) (Differ, error) {
lhsVal := reflect.ValueOf(lhs)
rhsVal := reflect.ValueOf(rhs)

if lhs == nil && rhs == nil {
return scalar{lhs, rhs}, nil
if d, ok := nilCheck(lhs, rhs); ok {
return d, nil
}
if lhs == nil || rhs == nil {
return types{lhs, rhs}, nil
if err := visited.add(lhsVal, rhsVal); err != nil {
return types{lhs, rhs}, ErrCyclic
}
if lhsVal.Type().Comparable() && rhsVal.Type().Comparable() {
return scalar{lhs, rhs}, nil
}
if lhsVal.Kind() != rhsVal.Kind() {
return types{lhs, rhs}, nil
}

if lhsVal.Kind() == reflect.Slice {
return newSlice(lhs, rhs)
return newSlice(lhs, rhs, visited)
}
if lhsVal.Kind() == reflect.Map {
return newMap(lhs, rhs)
return newMap(lhs, rhs, visited)
}

return types{lhs, rhs}, &ErrUnsupported{lhsVal.Type(), rhsVal.Type()}
}

func nilCheck(lhs, rhs interface{}) (Differ, bool) {
if lhs == nil && rhs == nil {
return scalar{lhs, rhs}, true
}
if lhs == nil || rhs == nil {
return types{lhs, rhs}, true
}

return nil, false
}

func (t Type) String() string {
switch t {
case Identical:
Expand Down Expand Up @@ -95,3 +111,46 @@ func IsMissing(d Differ) bool {
return true
}
}

type visited struct {
LHS []uintptr
RHS []uintptr
}

func (v *visited) add(lhs, rhs reflect.Value) error {
if canAddr(lhs) {
if inPointers(v.LHS, lhs) {
return ErrCyclic
}
v.LHS = append(v.LHS, lhs.Pointer())
}
if canAddr(rhs) {
if inPointers(v.RHS, rhs) {
return ErrCyclic
}
v.RHS = append(v.RHS, rhs.Pointer())
}

return nil
}

func inPointers(pointers []uintptr, val reflect.Value) bool {
for _, lhs := range pointers {
if lhs == val.Pointer() {
return true
}
}

return false
}

func canAddr(val reflect.Value) bool {
switch val.Kind() {
case reflect.Chan, reflect.Func, reflect.Map:
fallthrough
case reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
return true
}

return false
}
53 changes: 47 additions & 6 deletions diff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func TestSlice(t *testing.T) {
Type: ContentDiffer,
},
} {
typ, err := newSlice(test.LHS, test.RHS)
typ, err := newSlice(test.LHS, test.RHS, &visited{})

if err != nil {
t.Errorf("NewSlice(%+v, %+v): unexpected error: %q", test.LHS, test.RHS, err)
Expand All @@ -238,7 +238,7 @@ func TestSlice(t *testing.T) {
testStrings("TestSlice", t, test, ss, indented)
}

invalid, err := newSlice(nil, nil)
invalid, err := newSlice(nil, nil, &visited{})
if invalidErr, ok := err.(errInvalidType); ok {
if !strings.Contains(invalidErr.Error(), "nil") {
t.Errorf("NewSlice(nil, nil): unexpected format for InvalidType error: got %s", err)
Expand All @@ -256,7 +256,7 @@ func TestSlice(t *testing.T) {
t.Errorf("invalidSlice.StringIndent(%q, %q, %+v) = %q, expected %q", testKey, testPrefix, testOutput, indented, "")
}

invalid, err = newSlice([]int{}, nil)
invalid, err = newSlice([]int{}, nil, &visited{})
if invalidErr, ok := err.(errInvalidType); ok {
if !strings.Contains(invalidErr.Error(), "nil") {
t.Errorf("NewSlice([]int{}, nil): unexpected format for InvalidType error: got %s", err)
Expand Down Expand Up @@ -338,7 +338,7 @@ func TestMap(t *testing.T) {
Type: ContentDiffer,
},
} {
m, err := newMap(test.LHS, test.RHS)
m, err := newMap(test.LHS, test.RHS, &visited{})

if err != nil {
t.Errorf("NewMap(%+v, %+v): unexpected error: %q", test.LHS, test.RHS, err)
Expand All @@ -353,7 +353,7 @@ func TestMap(t *testing.T) {
testStrings(fmt.Sprintf("TestMap[%d]", i), t, test, ss, indented)
}

invalid, err := newMap(nil, nil)
invalid, err := newMap(nil, nil, &visited{})
if invalidErr, ok := err.(errInvalidType); ok {
if !strings.Contains(invalidErr.Error(), "nil") {
t.Errorf("NewMap(nil, nil): unexpected format for InvalidType error: got %s", err)
Expand All @@ -371,7 +371,7 @@ func TestMap(t *testing.T) {
t.Errorf("invalidMap.StringIndent(%q, %q, %+v) = %q, expected %q", testKey, testPrefix, testOutput, indented, "")
}

invalid, err = newMap(map[int]int{}, nil)
invalid, err = newMap(map[int]int{}, nil, &visited{})
if invalidErr, ok := err.(errInvalidType); ok {
if !strings.Contains(invalidErr.Error(), "nil") {
t.Errorf("NewMap(map[int]int{}, nil): unexpected format for InvalidType error: got %s", err)
Expand All @@ -390,6 +390,47 @@ func TestMap(t *testing.T) {
}
}

func TestCircular(t *testing.T) {
first := map[int]interface{}{}
second := map[int]interface{}{
0: first,
}
first[0] = second
notCyclic := map[int]interface{}{
0: map[int]interface{}{
0: map[int]interface{}{
0: "foo",
},
},
}

for _, test := range []struct {
lhs interface{}
rhs interface{}
wantError bool
}{
{lhs: first, rhs: first, wantError: true},
{lhs: first, rhs: second, wantError: true},
{lhs: first, rhs: second, wantError: true},
{lhs: first, rhs: notCyclic, wantError: true},
{lhs: notCyclic, rhs: first, wantError: true},
{lhs: notCyclic, rhs: notCyclic},
} {
d, err := Diff(test.lhs, test.rhs)

if test.wantError && (err == nil || err != ErrCyclic) {
t.Errorf("Expected error %q, got %q", ErrCyclic, err)
}
if !test.wantError && err != nil {
t.Errorf("Unexpected error %q", err)
}

if test.wantError && d.Diff() != ContentDiffer {
t.Errorf("Expected Diff() to be %s, got %s", ContentDiffer, d.Diff())
}
}
}

func TestIgnore(t *testing.T) {
ignoreDiff, _ := Ignore()

Expand Down
4 changes: 4 additions & 0 deletions diff/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package diff

import (
"errors"
"fmt"
"reflect"
)
Expand All @@ -23,3 +24,6 @@ type errInvalidType struct {
func (e errInvalidType) Error() string {
return fmt.Sprintf("%T is not a valid type for %s", e.Value, e.For)
}

// ErrCyclic is returned when one of the compared values contain circular references
var ErrCyclic = errors.New("circular references not supported")
4 changes: 2 additions & 2 deletions diff/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type mapExcess struct {
value interface{}
}

func newMap(lhs, rhs interface{}) (mapDiff, error) {
func newMap(lhs, rhs interface{}, visited *visited) (Differ, error) {
var diffs = make(map[interface{}]Differ)

lhsVal := reflect.ValueOf(lhs)
Expand All @@ -41,7 +41,7 @@ func newMap(lhs, rhs interface{}) (mapDiff, error) {
rhsEl := rhsVal.MapIndex(key)

if lhsEl.IsValid() && rhsEl.IsValid() {
diff, err := Diff(lhsEl.Interface(), rhsEl.Interface())
diff, err := diff(lhsEl.Interface(), rhsEl.Interface(), visited)
if diff.Diff() != Identical {
}
diffs[key.Interface()] = diff
Expand Down
4 changes: 2 additions & 2 deletions diff/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type sliceExcess struct {
value interface{}
}

func newSlice(lhs, rhs interface{}) (Differ, error) {
func newSlice(lhs, rhs interface{}, visited *visited) (Differ, error) {
var diffs []Differ

lhsVal := reflect.ValueOf(lhs)
Expand All @@ -39,7 +39,7 @@ func newSlice(lhs, rhs interface{}) (Differ, error) {

for i := 0; i < nElems; i++ {
if i < lhsVal.Len() && i < rhsVal.Len() {
diff, err := Diff(lhsVal.Index(i).Interface(), rhsVal.Index(i).Interface())
diff, err := diff(lhsVal.Index(i).Interface(), rhsVal.Index(i).Interface(), visited)
if diff.Diff() != Identical {
}
diffs = append(diffs, diff)
Expand Down

0 comments on commit 34459d8

Please sign in to comment.