diff --git a/examples/gno.land/p/demo/uint256/arithmetic.gno b/examples/gno.land/p/demo/uint256/arithmetic.gno index c3e2ed83738..92afe3031ca 100644 --- a/examples/gno.land/p/demo/uint256/arithmetic.gno +++ b/examples/gno.land/p/demo/uint256/arithmetic.gno @@ -416,6 +416,11 @@ func (z *Uint) isBitSet(n uint) bool { return (z.arr[n/64] & (1 << (n % 64))) != 0 } +// IsOverflow checks if the number is too large to be represented as a 256-bit unsigned integer. +func (z *Uint) IsOverflow() bool { + return z.isBitSet(255) +} + // addTo computes x += y. // Requires len(x) >= len(y). func addTo(x, y []uint64) uint64 { diff --git a/examples/gno.land/p/demo/uint256/arithmetic_test.gno b/examples/gno.land/p/demo/uint256/arithmetic_test.gno index addd33db997..dcd7b471e8e 100644 --- a/examples/gno.land/p/demo/uint256/arithmetic_test.gno +++ b/examples/gno.land/p/demo/uint256/arithmetic_test.gno @@ -2,6 +2,8 @@ package uint256 import ( "testing" + + "gno.land/p/demo/uassert" ) type binOp2Test struct { @@ -25,9 +27,7 @@ func TestAdd(t *testing.T) { want := MustFromDecimal(tt.want) got := new(Uint).Add(x, y) - if got.Neq(want) { - t.Errorf("Add(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -41,8 +41,8 @@ func TestAddOverflow(t *testing.T) { {"1", "0", "1", false}, {"1", "1", "2", false}, {"10", "10", "20", false}, - {"18446744073709551615", "18446744073709551615", "36893488147419103230", false}, // uint64 overflow, but not Uint256 overflow - {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "1", "0", true}, // 2^256 - 1 + 1, should overflow + {"18446744073709551615", "18446744073709551615", "36893488147419103230", false}, // uint64 overflow, but not Uint256 overflow + {twoPow256Sub1, "1", "0", true}, {"57896044618658097711785492504343953926634992332820282019728792003956564819967", "57896044618658097711785492504343953926634992332820282019728792003956564819968", "115792089237316195423570985008687907853269984665640564039457584007913129639935", false}, // (2^255 - 1) + 2^255, no overflow {"57896044618658097711785492504343953926634992332820282019728792003956564819967", "57896044618658097711785492504343953926634992332820282019728792003956564819969", "0", true}, // (2^255 - 1) + (2^255 + 1), should overflow } @@ -54,10 +54,8 @@ func TestAddOverflow(t *testing.T) { got, overflow := new(Uint).AddOverflow(x, y) - if got.Cmp(want) != 0 || overflow != tt.overflow { - t.Errorf("AddOverflow(%s, %s) = (%s, %v), want (%s, %v)", - tt.x, tt.y, got.String(), overflow, tt.want, tt.overflow) - } + uassert.Equal(t, got.String(), want.String()) + uassert.Equal(t, overflow, tt.overflow) } } @@ -75,15 +73,9 @@ func TestSub(t *testing.T) { y := MustFromDecimal(tc.y) want := MustFromDecimal(tc.want) - got := new(Uint).Sub(x, y) - if got.Neq(want) { - t.Errorf( - "Sub(%s, %s) = %v, want %v", - tc.x, tc.y, got.String(), want.String(), - ) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -109,12 +101,8 @@ func TestSubOverflow(t *testing.T) { got, overflow := new(Uint).SubOverflow(x, y) - if got.Cmp(want) != 0 || overflow != tc.overflow { - t.Errorf( - "SubOverflow(%s, %s) = (%s, %v), want (%s, %v)", - tc.x, tc.y, got.String(), overflow, tc.want, tc.overflow, - ) - } + uassert.Equal(t, got.String(), want.String()) + uassert.Equal(t, overflow, tc.overflow) } } @@ -132,9 +120,7 @@ func TestMul(t *testing.T) { want := MustFromDecimal(tt.want) got := new(Uint).Mul(x, y) - if got.Neq(want) { - t.Errorf("Mul(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -162,15 +148,8 @@ func TestMulOverflow(t *testing.T) { gotZ, gotOver := new(Uint).MulOverflow(x, y) - if gotZ.Neq(wantZ) { - t.Errorf( - "MulOverflow(%s, %s) = %s, want %s", - tt.x, tt.y, gotZ.String(), wantZ.String(), - ) - } - if gotOver != tt.wantOver { - t.Errorf("MulOverflow(%s, %s) = %v, want %v", tt.x, tt.y, gotOver, tt.wantOver) - } + uassert.Equal(t, gotZ.String(), wantZ.String()) + uassert.Equal(t, gotOver, tt.wantOver) } } @@ -187,13 +166,11 @@ func TestDiv(t *testing.T) { for _, tt := range tests { x := MustFromDecimal(tt.x) y := MustFromDecimal(tt.y) - want := MustFromDecimal(tt.want) + want := MustFromDecimal(tt.want) got := new(Uint).Div(x, y) - if got.Neq(want) { - t.Errorf("Div(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -212,13 +189,11 @@ func TestMod(t *testing.T) { for _, tt := range tests { x := MustFromDecimal(tt.x) y := MustFromDecimal(tt.y) - want := MustFromDecimal(tt.want) + want := MustFromDecimal(tt.want) got := new(Uint).Mod(x, y) - if got.Neq(want) { - t.Errorf("Mod(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -245,16 +220,11 @@ func TestMulMod(t *testing.T) { x := MustFromHex(tt.x) y := MustFromHex(tt.y) m := MustFromHex(tt.m) - want := MustFromHex(tt.want) + want := MustFromHex(tt.want) got := new(Uint).MulMod(x, y, m) - if got.Neq(want) { - t.Errorf( - "MulMod(%s, %s, %s) = %s, want %s", - tt.x, tt.y, tt.m, got.String(), want.String(), - ) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -284,17 +254,12 @@ func TestDivMod(t *testing.T) { gotMod := new(Uint) gotDiv.DivMod(x, y, gotMod) - for i := range gotDiv.arr { - if gotDiv.arr[i] != wantDiv.arr[i] { - t.Errorf("DivMod(%s, %s) got Div %v, want Div %v", tt.x, tt.y, gotDiv, wantDiv) - break - } - } for i := range gotMod.arr { - if gotMod.arr[i] != wantMod.arr[i] { - t.Errorf("DivMod(%s, %s) got Mod %v, want Mod %v", tt.x, tt.y, gotMod, wantMod) - break - } + uassert.Equal(t, gotMod.arr[i], wantMod.arr[i]) + } + + for i := range gotDiv.arr { + uassert.Equal(t, gotDiv.arr[i], wantDiv.arr[i]) } } } @@ -314,12 +279,9 @@ func TestNeg(t *testing.T) { for _, tt := range tests { x := MustFromDecimal(tt.x) want := MustFromDecimal(tt.want) - got := new(Uint).Neg(x) - if got.Neq(want) { - t.Errorf("Neg(%s) = %v, want %v", tt.x, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -339,16 +301,11 @@ func TestExp(t *testing.T) { for _, tt := range tests { x := MustFromDecimal(tt.x) y := MustFromDecimal(tt.y) - want := MustFromDecimal(tt.want) + want := MustFromDecimal(tt.want) got := new(Uint).Exp(x, y) - if got.Neq(want) { - t.Errorf( - "Exp(%s, %s) = %v, want %v", - tt.x, tt.y, got.String(), want.String(), - ) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -378,15 +335,50 @@ func TestExp_LargeExponent(t *testing.T) { base := MustFromDecimal(tt.base) exponent := MustFromDecimal(tt.exponent) expected := MustFromDecimal(tt.expected) - result := new(Uint).Exp(base, exponent) - if result.Neq(expected) { - t.Errorf( - "Test %s failed. Expected %s, got %s", - tt.name, expected.String(), result.String(), - ) - } + uassert.Equal(t, result.String(), expected.String()) + }) + } +} + +func TestIsOverflow(t *testing.T) { + tests := []struct { + name string + input *Uint + expected bool + }{ + { + name: "Number greater than max value", + input: &Uint{arr: [4]uint64{ + ^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0), + }}, + expected: true, + }, + { + name: "Max value", + input: &Uint{arr: [4]uint64{ + ^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0) >> 1, + }}, + expected: false, + }, + { + name: "0", + input: &Uint{arr: [4]uint64{0, 0, 0, 0}}, + expected: false, + }, + { + name: "Only 255th bit set", + input: &Uint{arr: [4]uint64{ + 0, 0, 0, uint64(1) << 63, + }}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.input.IsOverflow(), tt.expected) }) } } diff --git a/examples/gno.land/p/demo/uint256/bits_table.gno b/examples/gno.land/p/demo/uint256/bits_table.gno index 53dbea94827..aaaef7ebaa2 100644 --- a/examples/gno.land/p/demo/uint256/bits_table.gno +++ b/examples/gno.land/p/demo/uint256/bits_table.gno @@ -1,11 +1,18 @@ // Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. - -// Code generated by go run make_tables.go. DO NOT EDIT. - package uint256 +// ntz8tab: A lookup table for 8-bit values (0-255) that shows +// the number of trailing zeros (zeros from the rightmost/LSB position). +// +// Example) 0x28 (binary 00101000) +// +// Binary: [ 0 0 1 0 1 0 0 0 ] +// ^^^^^^^^ +// 3 consecutive zeros +// +// ntz8tab[0x28] = 3 const ntz8tab = "" + "\x08\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + @@ -24,6 +31,15 @@ const ntz8tab = "" + "\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +// pop8tab: A lookup table for 8-bit values (0-255) that allows +// quick lookup of the number of set bits (bits set to 1). +// +// Example) 0xB2 (binary 10110010) +// +// Binary: [ 1 0 1 1 0 0 1 0 ] +// Total of 4 set bits +// +// pop8tab[0xB2] = 4 const pop8tab = "" + "\x00\x01\x01\x02\x01\x02\x02\x03\x01\x02\x02\x03\x02\x03\x03\x04" + "\x01\x02\x02\x03\x02\x03\x03\x04\x02\x03\x03\x04\x03\x04\x04\x05" + @@ -42,6 +58,15 @@ const pop8tab = "" + "\x03\x04\x04\x05\x04\x05\x05\x06\x04\x05\x05\x06\x05\x06\x06\x07" + "\x04\x05\x05\x06\x05\x06\x06\x07\x05\x06\x06\x07\x06\x07\x07\x08" +// rev8tab: A lookup table that pre-calculates bit-reversed results +// for 8-bit values (0-255). +// +// Example) 0x16 (binary 00010110) +// +// Binary: [ 0 0 0 1 0 1 1 0 ] +// Reversed: [ 0 1 1 0 1 0 0 0 ] -> 0x68 (104 in decimal) +// +// rev8tab[0x16] = 0x68 const rev8tab = "" + "\x00\x80\x40\xc0\x20\xa0\x60\xe0\x10\x90\x50\xd0\x30\xb0\x70\xf0" + "\x08\x88\x48\xc8\x28\xa8\x68\xe8\x18\x98\x58\xd8\x38\xb8\x78\xf8" + @@ -60,6 +85,17 @@ const rev8tab = "" + "\x07\x87\x47\xc7\x27\xa7\x67\xe7\x17\x97\x57\xd7\x37\xb7\x77\xf7" + "\x0f\x8f\x4f\xcf\x2f\xaf\x6f\xef\x1f\x9f\x5f\xdf\x3f\xbf\x7f\xff" +// len8tab: A lookup table that pre-calculates the "bit length" +// of 8-bit values (0-255). +// (Bit length: position of the most significant bit + 1) +// +// Examples) +// +// 0x00 (binary 00000000) → No MSB → length 0 +// 0x01 (binary 00000001) → MSB at rightmost position → length 1 +// 0x02 (binary 00000010) ~ 0x03 (00000011) → length 2 +// 0x04 (binary 00000100) ~ 0x07 (00000111) → length 3 +// ... const len8tab = "" + "\x00\x01\x02\x02\x03\x03\x03\x03\x04\x04\x04\x04\x04\x04\x04\x04" + "\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05" + diff --git a/examples/gno.land/p/demo/uint256/bitwise_test.gno b/examples/gno.land/p/demo/uint256/bitwise_test.gno index 45118af0b0f..b40a24d18f7 100644 --- a/examples/gno.land/p/demo/uint256/bitwise_test.gno +++ b/examples/gno.land/p/demo/uint256/bitwise_test.gno @@ -1,6 +1,10 @@ package uint256 -import "testing" +import ( + "testing" + + "gno.land/p/demo/uassert" +) type logicOpTest struct { name string @@ -40,12 +44,7 @@ func TestOr(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := new(Uint).Or(&tt.x, &tt.y) - if *res != tt.want { - t.Errorf( - "Or(%s, %s) = %s, want %s", - tt.x.String(), tt.y.String(), res.String(), (tt.want).String(), - ) - } + uassert.Equal(t, res.String(), tt.want.String()) }) } } @@ -99,12 +98,7 @@ func TestAnd(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := new(Uint).And(&tt.x, &tt.y) - if *res != tt.want { - t.Errorf( - "And(%s, %s) = %s, want %s", - tt.x.String(), tt.y.String(), res.String(), (tt.want).String(), - ) - } + uassert.Equal(t, res.String(), tt.want.String()) }) } } @@ -135,12 +129,7 @@ func TestNot(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := new(Uint).Not(&tt.x) - if *res != tt.want { - t.Errorf( - "Not(%s) = %s, want %s", - tt.x.String(), res.String(), (tt.want).String(), - ) - } + uassert.Equal(t, res.String(), tt.want.String()) }) } } @@ -194,12 +183,7 @@ func TestAndNot(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := new(Uint).AndNot(&tt.x, &tt.y) - if *res != tt.want { - t.Errorf( - "AndNot(%s, %s) = %s, want %s", - tt.x.String(), tt.y.String(), res.String(), (tt.want).String(), - ) - } + uassert.Equal(t, res.String(), tt.want.String()) }) } } @@ -253,12 +237,7 @@ func TestXor(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := new(Uint).Xor(&tt.x, &tt.y) - if *res != tt.want { - t.Errorf( - "Xor(%s, %s) = %s, want %s", - tt.x.String(), tt.y.String(), res.String(), (tt.want).String(), - ) - } + uassert.Equal(t, res.String(), tt.want.String()) }) } } @@ -306,13 +285,11 @@ func TestLsh(t *testing.T) { for _, tt := range tests { x := MustFromDecimal(tt.x) - want := MustFromDecimal(tt.want) + want := MustFromDecimal(tt.want) got := new(Uint).Lsh(x, tt.y) - if got.Neq(want) { - t.Errorf("Lsh(%s, %d) = %s, want %s", tt.x, tt.y, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -356,9 +333,7 @@ func TestRsh(t *testing.T) { want := MustFromDecimal(tt.want) got := new(Uint).Rsh(x, tt.y) - if got.Neq(want) { - t.Errorf("Rsh(%s, %d) = %s, want %s", tt.x, tt.y, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } @@ -413,11 +388,8 @@ func TestSRsh(t *testing.T) { for _, tt := range tests { x := MustFromHex(tt.x) want := MustFromHex(tt.want) - got := new(Uint).SRsh(x, tt.y) - if !got.Eq(want) { - t.Errorf("SRsh(%s, %d) = %s, want %s", tt.x, tt.y, got.String(), want.String()) - } + uassert.Equal(t, got.String(), want.String()) } } diff --git a/examples/gno.land/p/demo/uint256/cmp_test.gno b/examples/gno.land/p/demo/uint256/cmp_test.gno index 05243290271..2c57efe6cb2 100644 --- a/examples/gno.land/p/demo/uint256/cmp_test.gno +++ b/examples/gno.land/p/demo/uint256/cmp_test.gno @@ -3,6 +3,8 @@ package uint256 import ( "strings" "testing" + + "gno.land/p/demo/uassert" ) func TestSign(t *testing.T) { @@ -31,9 +33,7 @@ func TestSign(t *testing.T) { for _, tt := range tests { t.Run(tt.input.String(), func(t *testing.T) { result := tt.input.Sign() - if result != tt.expected { - t.Errorf("Sign() = %d; want %d", result, tt.expected) - } + uassert.Equal(t, result, tt.expected) }) } } @@ -56,10 +56,7 @@ func TestCmp(t *testing.T) { x := MustFromDecimal(tc.x) y := MustFromDecimal(tc.y) - got := x.Cmp(y) - if got != tc.want { - t.Errorf("Cmp(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) - } + uassert.Equal(t, x.Cmp(y), tc.want) } } @@ -75,11 +72,7 @@ func TestIsZero(t *testing.T) { for _, tt := range tests { x := MustFromDecimal(tt.x) - - got := x.IsZero() - if got != tt.want { - t.Errorf("IsZero(%s) = %v, want %v", tt.x, got, tt.want) - } + uassert.Equal(t, x.IsZero(), tt.want) } } @@ -98,11 +91,7 @@ func TestLtUint64(t *testing.T) { for _, tc := range tests { x := parseTestString(t, tc.x) - - got := x.LtUint64(tc.y) - if got != tc.want { - t.Errorf("LtUint64(%s, %d) = %v, want %v", tc.x, tc.y, got, tc.want) - } + uassert.Equal(t, x.LtUint64(tc.y), tc.want) } } @@ -136,10 +125,7 @@ func TestUint_GtUint64(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { z := MustFromDecimal(tt.z) - - if got := z.GtUint64(tt.n); got != tt.want { - t.Errorf("Uint.GtUint64() = %v, want %v", got, tt.want) - } + uassert.Equal(t, z.GtUint64(tt.n), tt.want) }) } } @@ -147,17 +133,11 @@ func TestUint_GtUint64(t *testing.T) { func TestSGT(t *testing.T) { x := MustFromHex("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe") y := MustFromHex("0x0") - actual := x.Sgt(y) - if actual { - t.Fatalf("Expected %v false", actual) - } + uassert.False(t, x.Sgt(y)) x = MustFromHex("0x0") y = MustFromHex("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe") - actual = x.Sgt(y) - if !actual { - t.Fatalf("Expected %v true", actual) - } + uassert.True(t, x.Sgt(y)) } func TestEq(t *testing.T) { @@ -181,11 +161,7 @@ func TestEq(t *testing.T) { continue } - got := x.Eq(y) - - if got != tt.want { - t.Errorf("Eq(%s, %s) = %v, want %v", tt.x, tt.y, got, tt.want) - } + uassert.Equal(t, x.Eq(y), tt.want) } } @@ -202,18 +178,12 @@ func TestUint_Lte(t *testing.T) { for _, tt := range tests { z, err := FromDecimal(tt.z) - if err != nil { - t.Error(err) - continue - } + uassert.NoError(t, err) + x, err := FromDecimal(tt.x) - if err != nil { - t.Error(err) - continue - } - if got := z.Lte(x); got != tt.want { - t.Errorf("Uint.Lte(%v, %v) = %v, want %v", tt.z, tt.x, got, tt.want) - } + uassert.NoError(t, err) + + uassert.Equal(t, z.Lte(x), tt.want) } } @@ -232,13 +202,12 @@ func TestUint_Gte(t *testing.T) { z := parseTestString(t, tt.z) x := parseTestString(t, tt.x) - if got := z.Gte(x); got != tt.want { - t.Errorf("Uint.Gte(%v, %v) = %v, want %v", tt.z, tt.x, got, tt.want) - } + uassert.Equal(t, z.Gte(x), tt.want) } } -func parseTestString(_ *testing.T, s string) *Uint { +func parseTestString(t *testing.T, s string) *Uint { + t.Helper() var x *Uint if strings.HasPrefix(s, "0x") { diff --git a/examples/gno.land/p/demo/uint256/conversion_test.gno b/examples/gno.land/p/demo/uint256/conversion_test.gno index 3942a102511..73e0029c10b 100644 --- a/examples/gno.land/p/demo/uint256/conversion_test.gno +++ b/examples/gno.land/p/demo/uint256/conversion_test.gno @@ -1,6 +1,10 @@ package uint256 -import "testing" +import ( + "testing" + + "gno.land/p/demo/uassert" +) func TestIsUint64(t *testing.T) { tests := []struct { @@ -16,11 +20,7 @@ func TestIsUint64(t *testing.T) { for _, tt := range tests { x := MustFromHex(tt.x) - got := x.IsUint64() - - if got != tt.want { - t.Errorf("IsUint64(%s) = %v, want %v", tt.x, got, tt.want) - } + uassert.Equal(t, x.IsUint64(), tt.want) } } @@ -50,9 +50,7 @@ func TestDec(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.z.Dec() - if result != tt.want { - t.Errorf("Dec(%v) = %s, want %s", tt.z, result, tt.want) - } + uassert.Equal(t, result, tt.want) }) } } @@ -107,16 +105,10 @@ func TestUint_Scan(t *testing.T) { err := z.Scan(tt.input) if tt.wantErr { - if err == nil { - t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) - } + uassert.Error(t, err) } else { - if err != nil { - t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr) - } - if !z.Eq(tt.want) { - t.Errorf("Scan() = %v, want %v", z, tt.want) - } + uassert.NoError(t, err) + uassert.Equal(t, z.String(), tt.want.String()) } }) } @@ -168,8 +160,6 @@ func TestSetBytes(t *testing.T) { z := new(Uint) z.SetBytes(test.input) expected := MustFromDecimal(test.expected) - if z.Cmp(expected) != 0 { - t.Errorf("SetBytes(%x) = %s, expected %s", test.input, z.String(), test.expected) - } + uassert.Equal(t, z.String(), expected.String()) } } diff --git a/examples/gno.land/p/demo/uint256/error.gno b/examples/gno.land/p/demo/uint256/error.gno index d200bb9cc8f..a4cf93ca480 100644 --- a/examples/gno.land/p/demo/uint256/error.gno +++ b/examples/gno.land/p/demo/uint256/error.gno @@ -16,6 +16,7 @@ var ( ErrBadEncodedLength = errors.New("bad ssz encoded length") ErrInvalidBase = errors.New("invalid base") ErrInvalidBitSize = errors.New("invalid bit size") + ErrDivisionByZero = errors.New("division by zero") ) type u256Error struct { diff --git a/examples/gno.land/p/demo/uint256/fullmath.gno b/examples/gno.land/p/demo/uint256/fullmath.gno new file mode 100644 index 00000000000..98ecf2809aa --- /dev/null +++ b/examples/gno.land/p/demo/uint256/fullmath.gno @@ -0,0 +1,149 @@ +package uint256 + +import ( + "gno.land/p/demo/ufmt" +) + +// MulDiv calculates floor(a * b / denominator) with full precision by +// performing a 512×256-bit division. It returns an error if the denominator is +// zero or if the intermediate results exceed the 256-bit range. +func MulDiv(a, b, denominator *Uint) (*Uint, error) { + prod0 := Zero() + prod1 := Zero() + + { + // mm is the modulo multiplication result using a mask of all 1s. + mm := new(Uint).MulMod(a, b, new(Uint).Not(Zero())) + prod0 = new(Uint).Mul(a, b) + + ltBool := mm.Lt(prod0) + ltUint := Zero() + if ltBool { + ltUint = One() + } + prod1 = new(Uint).Sub(new(Uint).Sub(mm, prod0), ltUint) + } + + // Handle non-overflow cases, 256 by 256 division. + if prod1.IsZero() { + if !denominator.Gt(Zero()) { + return nil, ErrDivisionByZero + } + result := new(Uint).Div(prod0, denominator) + return result, nil + } + + // Ensure the result is less than 2**256. + // Also prevents denominator == 0. + if !denominator.Gt(prod1) { + return nil, ufmt.Errorf("denominator (%s) must be greater than prod1 (%s)", + denominator.String(), prod1.String()) + } + + /////////////////////////////////////////////// + // 512 by 256 division. + /////////////////////////////////////////////// + + // Make division exact by subtracting the remainder from [prod1 prod0]. + remainder := new(Uint).MulMod(a, b, denominator) + + // Subtract 256-bit number from 512-bit number. + gtBool := remainder.Gt(prod0) + gtUint := Zero() + if gtBool { + gtUint = One() + } + prod1 = new(Uint).Sub(prod1, gtUint) + prod0 = new(Uint).Sub(prod0, remainder) + + // Factor powers of two out of denominator. + // Compute largest power of two divisor of denominator. Always >= 1. + twos := new(Uint).And(new(Uint).Neg(denominator), denominator) + + // Divide denominator by the power of two. + denomDiv := new(Uint).Div(denominator, twos) + denominator = denomDiv + + // Divide [prod1 prod0] by the factors of two. + prod0 = new(Uint).Div(prod0, twos) + + // Shift bits from prod1 into prod0. + // For this we need to flip `twos` such that it is 2**256 / twos. + // If twos is zero, then it becomes one. + twos = new(Uint).Add( + new(Uint).Div( + new(Uint).Sub(Zero(), twos), + twos, + ), + One(), + ) + prod0 = new(Uint).Or(prod0, new(Uint).Mul(prod1, twos)) + + // Invert denominator mod 2**256. + // Since denominator is now odd, it has an inverse modulo 2**256 + // such that denominator * inv = 1 mod 2**256. + // Compute the inverse using a seed that is correct for four bits. + inv := new(Uint).Mul(NewUint(3), denominator) + inv = new(Uint).Xor(inv, NewUint(2)) + + // Use Newton-Raphson iteration to improve the precision. + // Each iteration doubles the correct bits. + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**8 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**16 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**32 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**64 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**128 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**256 + + // Because the division is now exact, we can obtain the final result + // by multiplying prod0 with the modular inverse of denominator. + result := new(Uint).Mul(prod0, inv) + return result, nil +} + +// MulDivRoundingUp calculates ceil(a * b / denominator) with full precision. +// If a * b is not an exact multiple of the denominator, the result is incremented by one. +// It returns an error if denominator is zero or if the intermediate results exceed the 256-bit range. +func MulDivRoundingUp(a, b, denominator *Uint) (*Uint, error) { + result, err := MulDiv(a, b, denominator) + if err != nil { + return nil, err + } + + remainder := new(Uint).MulMod(a, b, denominator) + if remainder.Gt(Zero()) { + // Here we ensure that result + 1 does not overflow 256 bits. + maxUint, err := FromDecimal(twoPow256Sub1) + if err != nil { + return nil, err + } + if !result.Lt(maxUint) { + return nil, err + } + result = new(Uint).Add(result, One()) + } + + return result, nil +} + +// DivRoundingUp performs division of x by y and rounds up the result. +// It returns an error if y is zero. +func DivRoundingUp(x, y *Uint) (*Uint, error) { + if y.IsZero() { + return nil, ErrDivisionByZero + } + + div := new(Uint).Div(x, y) + mod := new(Uint).Mod(x, y) + + // Add one if there is a remainder. + result := new(Uint).Add(div, gt(mod, Zero())) + return result, nil +} + +func gt(x, y *Uint) *Uint { + if x.Gt(y) { + return One() + } + return Zero() +} diff --git a/examples/gno.land/p/demo/uint256/fullmath_test.gno b/examples/gno.land/p/demo/uint256/fullmath_test.gno new file mode 100644 index 00000000000..848eaed440f --- /dev/null +++ b/examples/gno.land/p/demo/uint256/fullmath_test.gno @@ -0,0 +1,199 @@ +package uint256 + +import ( + "testing" + + "gno.land/p/demo/uassert" +) + +var Q128 *Uint + +func init() { + Q128 = MustFromDecimal("340282366920938463463374607431768211456") // 2**128 +} + +func TestMulDiv(t *testing.T) { + tests := []struct { + name string + a *Uint + b *Uint + denominator *Uint + want *Uint + wantErr bool + }{ + { + name: "simple multiplication and division", + a: NewUint(100), + b: NewUint(200), + denominator: NewUint(50), + want: NewUint(400), + wantErr: false, + }, + { + name: "division by zero", + a: NewUint(100), + b: NewUint(200), + denominator: Zero(), + want: nil, + wantErr: true, + }, + { + name: "zero numerator", + a: Zero(), + b: NewUint(200), + denominator: NewUint(50), + want: Zero(), + wantErr: false, + }, + { + name: "large numbers within bounds", + a: Q128, + b: NewUint(50), + denominator: NewUint(100), + want: new(Uint).Div(new(Uint).Mul(Q128, NewUint(50)), NewUint(100)), + wantErr: false, + }, + { + name: "max uint256 values", + a: MustFromDecimal(twoPow256Sub1), + b: One(), + denominator: One(), + want: MustFromDecimal(twoPow256Sub1), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MulDiv(tt.a, tt.b, tt.denominator) + if err != nil { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + uassert.Equal(t, got.String(), tt.want.String()) + } + }) + } +} + +func TestMulDivRoundingUp(t *testing.T) { + tests := []struct { + name string + a *Uint + b *Uint + denominator *Uint + want *Uint + wantErr bool + }{ + { + name: "exact division", + a: NewUint(100), + b: NewUint(200), + denominator: NewUint(50), + want: NewUint(400), + wantErr: false, + }, + { + name: "division with rounding up", + a: NewUint(101), + b: NewUint(200), + denominator: NewUint(50), + want: NewUint(404), + wantErr: false, + }, + { + name: "division by zero", + a: NewUint(100), + b: NewUint(200), + denominator: Zero(), + want: nil, + wantErr: true, + }, + { + name: "zero numerator", + a: Zero(), + b: NewUint(200), + denominator: NewUint(50), + want: Zero(), + wantErr: false, + }, + { + name: "large numbers with remainder", + a: Q128, + b: NewUint(51), + denominator: NewUint(100), + want: new(Uint).Add(new(Uint).Div(new(Uint).Mul(Q128, NewUint(51)), NewUint(100)), One()), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MulDivRoundingUp(tt.a, tt.b, tt.denominator) + if err != nil { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + uassert.Equal(t, got.String(), tt.want.String()) + } + }) + } +} + +func TestDivRoundingUp(t *testing.T) { + tests := []struct { + name string + x *Uint + y *Uint + want *Uint + wantErr bool + }{ + { + name: "division by zero", + x: NewUint(5), + y: Zero(), + want: nil, + wantErr: true, + }, + { + name: "simple division without remainder", + x: NewUint(10), + y: NewUint(2), + want: NewUint(5), + wantErr: false, + }, + { + name: "division with remainder should round up", + x: NewUint(11), + y: NewUint(3), + want: NewUint(4), + wantErr: false, + }, + { + name: "division of zero by non-zero", + x: Zero(), + y: NewUint(5), + want: Zero(), + wantErr: false, + }, + { + name: "division of max uint256 by one", + x: MustFromDecimal(twoPow256Sub1), + y: One(), + want: MustFromDecimal(twoPow256Sub1), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DivRoundingUp(tt.x, tt.y) + if err != nil { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + uassert.Equal(t, got.String(), tt.want.String()) + } + }) + } +} diff --git a/examples/gno.land/p/demo/uint256/uint256_test.gno b/examples/gno.land/p/demo/uint256/uint256_test.gno index ae8129b6e27..f92629b3be3 100644 --- a/examples/gno.land/p/demo/uint256/uint256_test.gno +++ b/examples/gno.land/p/demo/uint256/uint256_test.gno @@ -2,14 +2,14 @@ package uint256 import ( "testing" + + "gno.land/p/demo/uassert" ) func TestSetAllOne(t *testing.T) { z := Zero() z.SetAllOne() - if z.String() != twoPow256Sub1 { - t.Errorf("Expected all ones, got %s", z.String()) - } + uassert.Equal(t, z.String(), twoPow256Sub1) } func TestByte(t *testing.T) { @@ -44,9 +44,7 @@ func TestByte(t *testing.T) { n := NewUint(32) result := z.Byte(n) - if !result.IsZero() { - t.Errorf("Expected zero for position >= 32, got %v", result) - } + uassert.Equal(t, result.IsZero(), true) } func TestBitLen(t *testing.T) { @@ -71,10 +69,7 @@ func TestBitLen(t *testing.T) { z, _ := FromHex(tt.input) result := z.BitLen() - if result != tt.expected { - t.Errorf("Test case %d failed. Input: %s, Expected: %d, Got: %d", - i, tt.input, tt.expected, result) - } + uassert.Equal(t, result, tt.expected) } } @@ -99,11 +94,7 @@ func TestByteLen(t *testing.T) { for i, tt := range tests { z, _ := FromHex(tt.input) result := z.ByteLen() - - if result != tt.expected { - t.Errorf("Test case %d failed. Input: %s, Expected: %d, Got: %d", - i, tt.input, tt.expected, result) - } + uassert.Equal(t, result, tt.expected) } } @@ -120,8 +111,6 @@ func TestClone(t *testing.T) { for _, tt := range tests { z, _ := FromHex(tt.input) result := z.Clone() - if result.String() != tt.expected { - t.Errorf("Test %s failed. Expected %s, got %s", tt.input, tt.expected, result.String()) - } + uassert.Equal(t, result.String(), tt.expected) } }