Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Make assert package more type-safe
Browse files Browse the repository at this point in the history
Use generics to make the assert package more type-safe (and drop some
unfortunate reflection). Asserting that two values of different types
are equal is now a compile-time type error rather than a head-scratcher
of a diff!
  • Loading branch information
akshayjshah committed Feb 28, 2022
1 parent 2eb6ffa commit 9c574cb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
29 changes: 9 additions & 20 deletions internal/assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func Diff() Option {
}

// Equal asserts that two values are equal.
func Equal(t testing.TB, got, want any, msg string, opts ...Option) bool {
func Equal[T any](t testing.TB, got, want T, msg string, opts ...Option) bool {
t.Helper()
params := newParams(got, want, msg, opts...)
if cmp.Equal(got, want, params.cmpOpts...) {
Expand All @@ -84,7 +84,7 @@ func Equal(t testing.TB, got, want any, msg string, opts ...Option) bool {
}

// NotEqual asserts that two values aren't equal.
func NotEqual(t testing.TB, got, want any, msg string, opts ...Option) bool {
func NotEqual[T any](t testing.TB, got, want T, msg string, opts ...Option) bool {
t.Helper()
params := newParams(got, want, msg, opts...)
if !cmp.Equal(got, want, params.cmpOpts...) {
Expand Down Expand Up @@ -117,12 +117,9 @@ func NotNil(t testing.TB, got any, msg string, opts ...Option) bool {
}

// Zero asserts that the value is its type's zero value.
func Zero(t testing.TB, got any, msg string, opts ...Option) bool {
func Zero[T any](t testing.TB, got T, msg string, opts ...Option) bool {
t.Helper()
if got == nil {
return true
}
want := makeZero(got)
var want T
params := newParams(got, want, msg, opts...)
if cmp.Equal(got, want, params.cmpOpts...) {
return true
Expand All @@ -132,16 +129,13 @@ func Zero(t testing.TB, got any, msg string, opts ...Option) bool {
}

// NotZero asserts that the value is non-zero.
func NotZero(t testing.TB, got any, msg string, opts ...Option) bool {
func NotZero[T any](t testing.TB, got T, msg string, opts ...Option) bool {
t.Helper()
if got != nil {
want := makeZero(got)
params := newParams(got, want, msg, opts...)
if !cmp.Equal(got, want, params.cmpOpts...) {
return true
}
var want T
params := newParams(got, want, msg, opts...)
if !cmp.Equal(got, want, params.cmpOpts...) {
return true
}
params := newParams(got, nil, msg, opts...)
report(t, params, fmt.Sprintf("assert.NotZero (type %T)", got), false /* showWant */)
return false
}
Expand Down Expand Up @@ -243,8 +237,3 @@ func isNil(got any) bool {
return false
}
}

func makeZero(i any) any {
typ := reflect.TypeOf(i)
return reflect.Zero(typ).Interface()
}
3 changes: 2 additions & 1 deletion internal/assert/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func TestAssertions(t *testing.T) {
})

t.Run("zero", func(t *testing.T) {
Zero(t, nil, "")
var n *int
Zero(t, n, "")
var pair Pair
Zero(t, pair, "")
var null *Pair
Expand Down

0 comments on commit 9c574cb

Please sign in to comment.