diff --git a/interfaces/interfaces.go b/interfaces/interfaces.go index 9f4f26e..c80e0a2 100644 --- a/interfaces/interfaces.go +++ b/interfaces/interfaces.go @@ -11,6 +11,20 @@ type EqualFunc[A any] interface { Equal(A) bool } +// CopyFunc represents a type implementing the Copy method. +type CopyFunc[A any] interface { + Copy() A +} + +// CopyEqual represents a type satisfying both EqualFunc and CopyFunc. +type CopyEqual[T any] interface { + EqualFunc[T] + CopyFunc[T] +} + +// TweakFunc is used for modifying a value in tests. +type TweakFunc[E CopyEqual[E]] func(E) + // LessFunc represents any type implementing the Less method. type LessFunc[A any] interface { Less(A) bool diff --git a/internal/assertions/assertions.go b/internal/assertions/assertions.go index 3ded105..46d6a61 100644 --- a/internal/assertions/assertions.go +++ b/internal/assertions/assertions.go @@ -303,7 +303,6 @@ func Equal[E interfaces.EqualFunc[E]](exp, val E) (s string) { func NotEqual[E interfaces.EqualFunc[E]](exp, val E) (s string) { if val.Equal(exp) { s = "expected inequality via .Equal method\n" - s += diff(exp, val, nil) } return } @@ -1224,6 +1223,31 @@ func Wait(wc *wait.Constraint) (s string) { return } +type Tweak[E interfaces.CopyEqual[E]] struct { + Field string + Apply interfaces.TweakFunc[E] +} + +// StructEqual will apply each Tweak and assert E.Equal captures the modification. +func StructEqual[E interfaces.CopyEqual[E]](original E, tweaks []Tweak[E]) (s string) { + for _, tweak := range tweaks { + if tweak.Field == "" { + return "Tweak.Field must be set" + } else if tweak.Apply == nil { + return "Tweak.Apply must be set" + } + clone := original.Copy() + if s = Equal[E](original, clone); s != "" { + return + } + tweak.Apply(clone) + if s = NotEqual[E](original, clone); s != "" { + return + } + } + return +} + func bullet(msg string, args ...any) string { return fmt.Sprintf("↪ "+msg, args...) } diff --git a/internal/util/slices.go b/internal/util/slices.go new file mode 100644 index 0000000..4ea2563 --- /dev/null +++ b/internal/util/slices.go @@ -0,0 +1,10 @@ +package util + +// CloneSliceFunc creates a copy of A by first applying convert to each element. +func CloneSliceFunc[A, B any](original []A, convert func(item A) B) []B { + clone := make([]B, len(original)) + for i := 0; i < len(original); i++ { + clone[i] = convert(original[i]) + } + return clone +} diff --git a/internal/util/slices_test.go b/internal/util/slices_test.go new file mode 100644 index 0000000..64751ae --- /dev/null +++ b/internal/util/slices_test.go @@ -0,0 +1,36 @@ +package util + +import ( + "strconv" + "testing" +) + +func TestCloneSliceFunc(t *testing.T) { + t.Run("empty", func(t *testing.T) { + result := CloneSliceFunc([]int{}, func(i int) string { + return strconv.Itoa(i) + }) + if len(result) > 0 { + t.Fatal("expected empty slice") + } + }) + + t.Run("non empty", func(t *testing.T) { + original := []int{1, 4, 5} + result := CloneSliceFunc(original, func(i int) string { + return strconv.Itoa(i) + }) + if len(result) != 3 { + t.Fatal("expected length of 3") + } + if result[0] != "1" { + t.Fatal("expected result[0] == 1") + } + if result[1] != "4" { + t.Fatal("expected result[1] == 4") + } + if result[2] != "5" { + t.Fatal("expected result[2] == 5") + } + }) +} diff --git a/must/must.go b/must/must.go index bcf3f90..3c0ff56 100644 --- a/must/must.go +++ b/must/must.go @@ -12,6 +12,7 @@ import ( "github.com/shoenig/test/internal/assertions" "github.com/shoenig/test/internal/brokenfs" "github.com/shoenig/test/internal/constraints" + "github.com/shoenig/test/internal/util" "github.com/shoenig/test/wait" ) @@ -715,3 +716,27 @@ func Wait(t T, wc *wait.Constraint, settings ...Setting) { t.Helper() invoke(t, assertions.Wait(wc), settings...) } + +// Tweak is used to modify a struct and assert its Equal method captures the +// modification. +// +// Field is the name of the struct field and is used only for error printing. +// Apply is a function that modifies E. +type Tweak[E interfaces.CopyEqual[E]] struct { + Field string + Apply interfaces.TweakFunc[E] +} + +// StructEqual will apply each Tweak and assert E.Equal captures the modification. +func StructEqual[E interfaces.CopyEqual[E]](t T, original E, tweaks []Tweak[E], settings ...Setting) { + t.Helper() + invoke(t, assertions.StructEqual( + original, + util.CloneSliceFunc[Tweak[E], assertions.Tweak[E]]( + tweaks, + func(tweak Tweak[E]) assertions.Tweak[E] { + return assertions.Tweak[E]{Field: tweak.Field, Apply: tweak.Apply} + }, + ), + ), settings...) +} diff --git a/must/must_test.go b/must/must_test.go index c8b8943..eb5beb4 100644 --- a/must/must_test.go +++ b/must/must_test.go @@ -1477,6 +1477,32 @@ func (c *container[T]) Len() int { return c.length } +func (c *container[T]) Copy() *container[T] { + return &container[T]{ + contains: c.contains, + empty: c.empty, + size: c.size, + length: c.length, + } +} + +func (c *container[T]) Equal(o *container[T]) bool { + if c == nil || o == nil { + return c == o + } + switch { + case c.contains != o.contains: + return false + case c.empty != o.empty: + return false + case c.size != o.size: + return false + case c.length != o.length: + return false + } + return true +} + func TestEmpty(t *testing.T) { tc := newCase(t, `expected to be empty, but was not`) t.Cleanup(tc.assert) @@ -1562,3 +1588,27 @@ func TestWait_TestFunc(t *testing.T) { wait.Timeout(100*time.Millisecond), )) } + +func TestStructEqual(t *testing.T) { + tc := newCase(t, `expected inequality via .Equal method`) + t.Cleanup(tc.assert) + + StructEqual[*container[int]](tc, &container[int]{ + contains: true, + empty: true, + size: 1, + length: 2, + }, []Tweak[*container[int]]{{ + Field: "contains", + Apply: func(c *container[int]) { c.contains = false }, + }, { + Field: "empty", + Apply: func(c *container[int]) { c.empty = false }, + }, { + Field: "size", + Apply: func(c *container[int]) { c.size = 9 }, + }, { + Field: "length", + Apply: func(c *container[int]) { c.length = 2 }, // no mod + }}) +} diff --git a/test.go b/test.go index 6239c35..60aefb5 100644 --- a/test.go +++ b/test.go @@ -10,6 +10,7 @@ import ( "github.com/shoenig/test/internal/assertions" "github.com/shoenig/test/internal/brokenfs" "github.com/shoenig/test/internal/constraints" + "github.com/shoenig/test/internal/util" "github.com/shoenig/test/wait" ) @@ -713,3 +714,27 @@ func Wait(t T, wc *wait.Constraint, settings ...Setting) { t.Helper() invoke(t, assertions.Wait(wc), settings...) } + +// Tweak is used to modify a struct and assert its Equal method captures the +// modification. +// +// Field is the name of the struct field and is used only for error printing. +// Apply is a function that modifies E. +type Tweak[E interfaces.CopyEqual[E]] struct { + Field string + Apply interfaces.TweakFunc[E] +} + +// StructEqual will apply each Tweak and assert E.Equal captures the modification. +func StructEqual[E interfaces.CopyEqual[E]](t T, original E, tweaks []Tweak[E], settings ...Setting) { + t.Helper() + invoke(t, assertions.StructEqual( + original, + util.CloneSliceFunc[Tweak[E], assertions.Tweak[E]]( + tweaks, + func(tweak Tweak[E]) assertions.Tweak[E] { + return assertions.Tweak[E]{Field: tweak.Field, Apply: tweak.Apply} + }, + ), + ), settings...) +} diff --git a/test_test.go b/test_test.go index fe0628b..79c8a56 100644 --- a/test_test.go +++ b/test_test.go @@ -1475,6 +1475,32 @@ func (c *container[T]) Len() int { return c.length } +func (c *container[T]) Copy() *container[T] { + return &container[T]{ + contains: c.contains, + empty: c.empty, + size: c.size, + length: c.length, + } +} + +func (c *container[T]) Equal(o *container[T]) bool { + if c == nil || o == nil { + return c == o + } + switch { + case c.contains != o.contains: + return false + case c.empty != o.empty: + return false + case c.size != o.size: + return false + case c.length != o.length: + return false + } + return true +} + func TestEmpty(t *testing.T) { tc := newCase(t, `expected to be empty, but was not`) t.Cleanup(tc.assert) @@ -1560,3 +1586,27 @@ func TestWait_TestFunc(t *testing.T) { wait.Timeout(100*time.Millisecond), )) } + +func TestStructEqual(t *testing.T) { + tc := newCase(t, `expected inequality via .Equal method`) + t.Cleanup(tc.assert) + + StructEqual[*container[int]](tc, &container[int]{ + contains: true, + empty: true, + size: 1, + length: 2, + }, []Tweak[*container[int]]{{ + Field: "contains", + Apply: func(c *container[int]) { c.contains = false }, + }, { + Field: "empty", + Apply: func(c *container[int]) { c.empty = false }, + }, { + Field: "size", + Apply: func(c *container[int]) { c.size = 9 }, + }, { + Field: "length", + Apply: func(c *container[int]) { c.length = 2 }, // no mod + }}) +}