diff --git a/libs/math/safemath.go b/libs/math/safemath.go index 9afb409b2..54f0b7a4f 100644 --- a/libs/math/safemath.go +++ b/libs/math/safemath.go @@ -8,9 +8,11 @@ import ( var ErrOverflowInt64 = errors.New("int64 overflow") var ErrOverflowInt32 = errors.New("int32 overflow") +var ErrOverflowUint64 = errors.New("uint64 overflow") var ErrOverflowUint32 = errors.New("uint32 overflow") var ErrOverflowUint8 = errors.New("uint8 overflow") var ErrOverflowInt8 = errors.New("int8 overflow") +var ErrOverflow = errors.New("integer overflow") // SafeAddClipInt64 adds two int64 integers and clips the result to the int64 range. func SafeAddClipInt64(a, b int64) int64 { @@ -94,10 +96,106 @@ func SafeConvertUint32[T Integer](a T) (uint32, error) { return uint32(a), nil } +// SafeConvertUint64 takes a int and checks if it overflows. +func SafeConvertUint64[T Integer](a T) (uint64, error) { + return SafeConvert[T, uint64](a) +} + +// SafeConvertInt64 takes a int and checks if it overflows. +func SafeConvertInt64[T Integer](a T) (int64, error) { + return SafeConvert[T, int64](a) +} + +// SafeConvertInt16 takes a int and checks if it overflows. +func SafeConvertInt16[T Integer](a T) (int16, error) { + return SafeConvert[T, int16](a) +} + +// SafeConvertUint16 takes a int and checks if it overflows. +func SafeConvertUint16[T Integer](a T) (uint16, error) { + return SafeConvert[T, uint16](a) +} + type Integer interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 } +// SafeConvert converts a value of type T to a value of type U. +// It returns an error if the conversion would cause an overflow. +func SafeConvert[T Integer, U Integer](from T) (U, error) { + const uintIsSmall = math.MaxUint < math.MaxUint64 + const intIsSmall = math.MaxInt < math.MaxInt64 && math.MinInt > math.MinInt64 + + // special case for int64 and uint64 inputs; all other types are safe to convert to int64 + switch any(from).(type) { + case int64: + // conversion from int64 to uint64 - we need to check for negative values + if _, ok := any(U(0)).(uint64); ok && from < 0 { + return 0, ErrOverflow + } + return U(from), nil + case uint64: + // conversion from uint64 to int64 - we need to check for overflow + if _, ok := any(U(0)).(int64); ok && uint64(from) > math.MaxInt64 { + return 0, ErrOverflow + } + return U(from), nil + case int: + if !intIsSmall { + return SafeConvert[int64, U](int64(from)) + } + // no return here - it's safe to use normal logic + case uint: + if !uintIsSmall { + return SafeConvert[uint64, U](uint64(from)) + } + // no return here - it's safe to use normal logic + } + if uint64(from) > Max[U]() { + return 0, ErrOverflow + } + if int64(from) < Min[U]() { + return 0, ErrOverflow + } + return U(from), nil +} + +func MustConvert[FROM Integer, TO Integer](a FROM) TO { + i, err := SafeConvert[FROM, TO](a) + if err != nil { + panic(fmt.Errorf("cannot convert %d to %T: %w", a, any(i), err)) + } + return i +} + +func MustConvertUint64[T Integer](a T) uint64 { + return MustConvert[T, uint64](a) +} + +func MustConvertInt64[T Integer](a T) int64 { + return MustConvert[T, int64](a) +} + +func MustConvertUint16[T Integer](a T) uint16 { + return MustConvert[T, uint16](a) +} + +func MustConvertInt16[T Integer](a T) int16 { + return MustConvert[T, int16](a) +} + +func MustConvertUint8[T Integer](a T) uint8 { + return MustConvert[T, uint8](a) +} + +func MustConvertUint[T Integer](a T) uint { + return MustConvert[T, uint](a) +} + +func MustConvertInt[T Integer](a T) int { + return MustConvert[T, int](a) +} + // MustConvertInt32 takes an Integer and converts it to int32. // Panics if the conversion overflows. func MustConvertInt32[T Integer](a T) int32 { @@ -159,3 +257,52 @@ func SafeMulInt64(a, b int64) (int64, bool) { return a * b, false } + +// Max returns the maximum value for a type T. +func Max[T Integer]() uint64 { + var max T + switch any(max).(type) { + case int: + return uint64(math.MaxInt) + case int8: + return uint64(math.MaxInt8) + case int16: + return uint64(math.MaxInt16) + case int32: + return uint64(math.MaxInt32) + case int64: + return uint64(math.MaxInt64) + case uint: + return uint64(math.MaxUint) + case uint8: + return uint64(math.MaxUint8) + case uint16: + return uint64(math.MaxUint16) + case uint32: + return uint64(math.MaxUint32) + case uint64: + return uint64(math.MaxUint64) + default: + panic("unsupported type") + } +} + +// Min returns the minimum value for a type T. +func Min[T Integer]() int64 { + switch any(T(0)).(type) { + case int: + return int64(math.MinInt) + case int8: + return int64(math.MinInt8) + case int16: + return int64(math.MinInt16) + case int32: + return int64(math.MinInt32) + case int64: + return math.MinInt64 + case uint, uint8, uint16, uint32, uint64: + return 0 + default: + panic("unsupported type") + } +} diff --git a/libs/math/safemath_test.go b/libs/math/safemath_test.go index 92a8f3211..98a57e116 100644 --- a/libs/math/safemath_test.go +++ b/libs/math/safemath_test.go @@ -84,3 +84,108 @@ func TestSafeMul(t *testing.T) { assert.Equal(t, tc.overflow, overflow, "#%d", i) } } + +func TestSafeConvert(t *testing.T) { + testCases := []struct { + from interface{} + want interface{} + err bool + }{ + {int(0), int64(0), false}, + {int(math.MaxInt), int64(math.MaxInt), false}, + {int(math.MinInt), int64(math.MinInt), false}, + {uint(0), uint64(0), false}, + {uint(math.MaxUint), uint64(math.MaxUint), false}, + {int64(0), uint64(0), false}, + {int64(math.MaxInt64), uint64(math.MaxInt64), false}, + {int64(math.MinInt64), uint64(0), true}, + {uint64(math.MaxUint64), int64(0), true}, + {uint64(math.MaxInt64), int64(math.MaxInt64), false}, + {int32(-1), uint32(0), true}, + {int32(0), uint32(0), false}, + {int32(math.MaxInt32), uint32(math.MaxInt32), false}, + {int32(math.MaxInt32), int16(0), true}, + {int32(math.MinInt32), int16(0), true}, + {int32(0), int16(0), false}, + {uint32(math.MaxUint32), int32(0), true}, + {uint32(math.MaxInt32), int32(math.MaxInt32), false}, + {uint32(0), int32(0), false}, + {int16(0), uint32(0), false}, + {int16(-1), uint32(0), true}, + {int16(math.MaxInt16), uint32(math.MaxInt16), false}, + } + + for i, tc := range testCases { + var result interface{} + var err error + + switch from := tc.from.(type) { + case int: + switch tc.want.(type) { + case int64: + result, err = SafeConvert[int, int64](from) + default: + t.Fatalf("test case %d: unsupported target type %T", i, tc.want) + } + case uint: + switch tc.want.(type) { + case uint64: + result, err = SafeConvert[uint, uint64](from) + default: + t.Fatalf("test case %d: unsupported target type %T", i, tc.want) + } + case int64: + switch tc.want.(type) { + case uint64: + result, err = SafeConvert[int64, uint64](from) + case int64: + result, err = SafeConvert[int64, int64](from) + default: + t.Fatalf("test case %d: unsupported target type %T", i, tc.want) + } + case uint64: + switch tc.want.(type) { + case int64: + result, err = SafeConvert[uint64, int64](from) + default: + t.Fatalf("test case %d: unsupported target type %T", i, tc.want) + } + case int32: + switch tc.want.(type) { + case int16: + result, err = SafeConvert[int32, int16](from) + case uint32: + result, err = SafeConvert[int32, uint32](from) + default: + t.Fatalf("test case %d: unsupported target type %T", i, tc.want) + } + case uint32: + switch tc.want.(type) { + case int16: + result, err = SafeConvert[uint32, int16](from) + case int32: + result, err = SafeConvert[uint32, int32](from) + default: + t.Fatalf("test case %d: unsupported target type %T", i, tc.want) + } + case int16: + switch tc.want.(type) { + case int32: + result, err = SafeConvert[int16, int32](from) + case uint32: + result, err = SafeConvert[int16, uint32](from) + default: + t.Fatalf("test case %d: unsupported target type %T", i, tc.want) + } + default: + t.Fatalf("test case %d: unsupported source type %T", i, tc.from) + } + + if (err != nil) != tc.err { + t.Errorf("test case %d: expected error %v, got %v", i, tc.err, err) + } + if err == nil && result != tc.want { + t.Errorf("test case %d: expected result %v, got %v", i, tc.want, result) + } + } +}