Skip to content

Commit

Permalink
chore: improve safemath lib
Browse files Browse the repository at this point in the history
  • Loading branch information
lklimek committed Jan 20, 2025
1 parent dabacf2 commit a326226
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 0 deletions.
147 changes: 147 additions & 0 deletions libs/math/safemath.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (

var ErrOverflowInt64 = errors.New("int64 overflow")
var ErrOverflowInt32 = errors.New("int32 overflow")
var ErrOverflowUint64 = errors.New("uint64 overflow")
var ErrOverflowUint32 = errors.New("uint32 overflow")
var ErrOverflowUint8 = errors.New("uint8 overflow")
var ErrOverflowInt8 = errors.New("int8 overflow")
var ErrOverflow = errors.New("integer overflow")

// SafeAddClipInt64 adds two int64 integers and clips the result to the int64 range.
func SafeAddClipInt64(a, b int64) int64 {
Expand Down Expand Up @@ -94,10 +96,106 @@ func SafeConvertUint32[T Integer](a T) (uint32, error) {
return uint32(a), nil
}

// SafeConvertUint64 takes a int and checks if it overflows.
func SafeConvertUint64[T Integer](a T) (uint64, error) {
return SafeConvert[T, uint64](a)
}

// SafeConvertInt64 takes a int and checks if it overflows.
func SafeConvertInt64[T Integer](a T) (int64, error) {
return SafeConvert[T, int64](a)
}

// SafeConvertInt16 takes a int and checks if it overflows.
func SafeConvertInt16[T Integer](a T) (int16, error) {
return SafeConvert[T, int16](a)
}

// SafeConvertUint16 takes a int and checks if it overflows.
func SafeConvertUint16[T Integer](a T) (uint16, error) {
return SafeConvert[T, uint16](a)
}

type Integer interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}

// SafeConvert converts a value of type T to a value of type U.
// It returns an error if the conversion would cause an overflow.
func SafeConvert[T Integer, U Integer](from T) (U, error) {
const uintIsSmall = math.MaxUint < math.MaxUint64
const intIsSmall = math.MaxInt < math.MaxInt64 && math.MinInt > math.MinInt64

// special case for int64 and uint64 inputs; all other types are safe to convert to int64
switch any(from).(type) {
case int64:
// conversion from int64 to uint64 - we need to check for negative values
if _, ok := any(U(0)).(uint64); ok && from < 0 {
return 0, ErrOverflow
}
return U(from), nil
case uint64:
// conversion from uint64 to int64 - we need to check for overflow
if _, ok := any(U(0)).(int64); ok && uint64(from) > math.MaxInt64 {
return 0, ErrOverflow
}
return U(from), nil
case int:
if !intIsSmall {
return SafeConvert[int64, U](int64(from))
}
// no return here - it's safe to use normal logic
case uint:
if !uintIsSmall {
return SafeConvert[uint64, U](uint64(from))
}
// no return here - it's safe to use normal logic
}
if uint64(from) > Max[U]() {
return 0, ErrOverflow
}
if int64(from) < Min[U]() {
return 0, ErrOverflow
}
return U(from), nil
}

func MustConvert[FROM Integer, TO Integer](a FROM) TO {
i, err := SafeConvert[FROM, TO](a)
if err != nil {
panic(fmt.Errorf("cannot convert %d to %T: %w", a, any(i), err))
}
return i
}

func MustConvertUint64[T Integer](a T) uint64 {
return MustConvert[T, uint64](a)
}

func MustConvertInt64[T Integer](a T) int64 {
return MustConvert[T, int64](a)
}

func MustConvertUint16[T Integer](a T) uint16 {
return MustConvert[T, uint16](a)
}

func MustConvertInt16[T Integer](a T) int16 {
return MustConvert[T, int16](a)
}

func MustConvertUint8[T Integer](a T) uint8 {
return MustConvert[T, uint8](a)
}

func MustConvertUint[T Integer](a T) uint {
return MustConvert[T, uint](a)
}

func MustConvertInt[T Integer](a T) int {
return MustConvert[T, int](a)
}

// MustConvertInt32 takes an Integer and converts it to int32.
// Panics if the conversion overflows.
func MustConvertInt32[T Integer](a T) int32 {
Expand Down Expand Up @@ -159,3 +257,52 @@ func SafeMulInt64(a, b int64) (int64, bool) {

return a * b, false
}

// Max returns the maximum value for a type T.
func Max[T Integer]() uint64 {
var max T
switch any(max).(type) {
case int:
return uint64(math.MaxInt)
case int8:
return uint64(math.MaxInt8)
case int16:
return uint64(math.MaxInt16)
case int32:
return uint64(math.MaxInt32)
case int64:
return uint64(math.MaxInt64)
case uint:
return uint64(math.MaxUint)
case uint8:
return uint64(math.MaxUint8)
case uint16:
return uint64(math.MaxUint16)
case uint32:
return uint64(math.MaxUint32)
case uint64:
return uint64(math.MaxUint64)
default:
panic("unsupported type")
}
}

// Min returns the minimum value for a type T.
func Min[T Integer]() int64 {
switch any(T(0)).(type) {
case int:
return int64(math.MinInt)
case int8:
return int64(math.MinInt8)
case int16:
return int64(math.MinInt16)
case int32:
return int64(math.MinInt32)
case int64:
return math.MinInt64
case uint, uint8, uint16, uint32, uint64:
return 0
default:
panic("unsupported type")
}
}
105 changes: 105 additions & 0 deletions libs/math/safemath_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,108 @@ func TestSafeMul(t *testing.T) {
assert.Equal(t, tc.overflow, overflow, "#%d", i)
}
}

func TestSafeConvert(t *testing.T) {
testCases := []struct {
from interface{}
want interface{}
err bool
}{
{int(0), int64(0), false},
{int(math.MaxInt), int64(math.MaxInt), false},
{int(math.MinInt), int64(math.MinInt), false},
{uint(0), uint64(0), false},
{uint(math.MaxUint), uint64(math.MaxUint), false},
{int64(0), uint64(0), false},
{int64(math.MaxInt64), uint64(math.MaxInt64), false},
{int64(math.MinInt64), uint64(0), true},
{uint64(math.MaxUint64), int64(0), true},
{uint64(math.MaxInt64), int64(math.MaxInt64), false},
{int32(-1), uint32(0), true},
{int32(0), uint32(0), false},
{int32(math.MaxInt32), uint32(math.MaxInt32), false},
{int32(math.MaxInt32), int16(0), true},
{int32(math.MinInt32), int16(0), true},
{int32(0), int16(0), false},
{uint32(math.MaxUint32), int32(0), true},
{uint32(math.MaxInt32), int32(math.MaxInt32), false},
{uint32(0), int32(0), false},
{int16(0), uint32(0), false},
{int16(-1), uint32(0), true},
{int16(math.MaxInt16), uint32(math.MaxInt16), false},
}

for i, tc := range testCases {
var result interface{}
var err error

switch from := tc.from.(type) {
case int:
switch tc.want.(type) {
case int64:
result, err = SafeConvert[int, int64](from)
default:
t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
}
case uint:
switch tc.want.(type) {
case uint64:
result, err = SafeConvert[uint, uint64](from)
default:
t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
}
case int64:
switch tc.want.(type) {
case uint64:
result, err = SafeConvert[int64, uint64](from)
case int64:
result, err = SafeConvert[int64, int64](from)
default:
t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
}
case uint64:
switch tc.want.(type) {
case int64:
result, err = SafeConvert[uint64, int64](from)
default:
t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
}
case int32:
switch tc.want.(type) {
case int16:
result, err = SafeConvert[int32, int16](from)
case uint32:
result, err = SafeConvert[int32, uint32](from)
default:
t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
}
case uint32:
switch tc.want.(type) {
case int16:
result, err = SafeConvert[uint32, int16](from)
case int32:
result, err = SafeConvert[uint32, int32](from)
default:
t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
}
case int16:
switch tc.want.(type) {
case int32:
result, err = SafeConvert[int16, int32](from)
case uint32:
result, err = SafeConvert[int16, uint32](from)
default:
t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
}
default:
t.Fatalf("test case %d: unsupported source type %T", i, tc.from)
}

if (err != nil) != tc.err {
t.Errorf("test case %d: expected error %v, got %v", i, tc.err, err)
}
if err == nil && result != tc.want {
t.Errorf("test case %d: expected result %v, got %v", i, tc.want, result)
}
}
}

0 comments on commit a326226

Please sign in to comment.