From 9c574cbde61433fbdf9bad168b5ddd78984bf1d8 Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Sat, 5 Feb 2022 00:01:35 -0800 Subject: [PATCH] Make assert package more type-safe Use generics to make the assert package more type-safe (and drop some unfortunate reflection). Asserting that two values of different types are equal is now a compile-time type error rather than a head-scratcher of a diff! --- internal/assert/assert.go | 29 +++++++++-------------------- internal/assert/assert_test.go | 3 ++- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 724398a9..a9b927f5 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -73,7 +73,7 @@ func Diff() Option { } // Equal asserts that two values are equal. -func Equal(t testing.TB, got, want any, msg string, opts ...Option) bool { +func Equal[T any](t testing.TB, got, want T, msg string, opts ...Option) bool { t.Helper() params := newParams(got, want, msg, opts...) if cmp.Equal(got, want, params.cmpOpts...) { @@ -84,7 +84,7 @@ func Equal(t testing.TB, got, want any, msg string, opts ...Option) bool { } // NotEqual asserts that two values aren't equal. -func NotEqual(t testing.TB, got, want any, msg string, opts ...Option) bool { +func NotEqual[T any](t testing.TB, got, want T, msg string, opts ...Option) bool { t.Helper() params := newParams(got, want, msg, opts...) if !cmp.Equal(got, want, params.cmpOpts...) { @@ -117,12 +117,9 @@ func NotNil(t testing.TB, got any, msg string, opts ...Option) bool { } // Zero asserts that the value is its type's zero value. -func Zero(t testing.TB, got any, msg string, opts ...Option) bool { +func Zero[T any](t testing.TB, got T, msg string, opts ...Option) bool { t.Helper() - if got == nil { - return true - } - want := makeZero(got) + var want T params := newParams(got, want, msg, opts...) if cmp.Equal(got, want, params.cmpOpts...) { return true @@ -132,16 +129,13 @@ func Zero(t testing.TB, got any, msg string, opts ...Option) bool { } // NotZero asserts that the value is non-zero. -func NotZero(t testing.TB, got any, msg string, opts ...Option) bool { +func NotZero[T any](t testing.TB, got T, msg string, opts ...Option) bool { t.Helper() - if got != nil { - want := makeZero(got) - params := newParams(got, want, msg, opts...) - if !cmp.Equal(got, want, params.cmpOpts...) { - return true - } + var want T + params := newParams(got, want, msg, opts...) + if !cmp.Equal(got, want, params.cmpOpts...) { + return true } - params := newParams(got, nil, msg, opts...) report(t, params, fmt.Sprintf("assert.NotZero (type %T)", got), false /* showWant */) return false } @@ -243,8 +237,3 @@ func isNil(got any) bool { return false } } - -func makeZero(i any) any { - typ := reflect.TypeOf(i) - return reflect.Zero(typ).Interface() -} diff --git a/internal/assert/assert_test.go b/internal/assert/assert_test.go index 7a5b3cfb..06a6fb3e 100644 --- a/internal/assert/assert_test.go +++ b/internal/assert/assert_test.go @@ -38,7 +38,8 @@ func TestAssertions(t *testing.T) { }) t.Run("zero", func(t *testing.T) { - Zero(t, nil, "") + var n *int + Zero(t, n, "") var pair Pair Zero(t, pair, "") var null *Pair