Skip to content

Commit

Permalink
update uint256
Browse files Browse the repository at this point in the history
  • Loading branch information
notJoon committed Feb 4, 2025
1 parent 2ec6a15 commit 40a7837
Show file tree
Hide file tree
Showing 10 changed files with 512 additions and 210 deletions.
5 changes: 5 additions & 0 deletions examples/gno.land/p/demo/uint256/arithmetic.gno
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,11 @@ func (z *Uint) isBitSet(n uint) bool {
return (z.arr[n/64] & (1 << (n % 64))) != 0
}

// IsOverflow checks if the number is too large to be represented as a 256-bit unsigned integer.
func (z *Uint) IsOverflow() bool {
return z.isBitSet(255)
}

// addTo computes x += y.
// Requires len(x) >= len(y).
func addTo(x, y []uint64) uint64 {
Expand Down
146 changes: 69 additions & 77 deletions examples/gno.land/p/demo/uint256/arithmetic_test.gno
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package uint256

import (
"testing"

"gno.land/p/demo/uassert"
)

type binOp2Test struct {
Expand All @@ -25,9 +27,7 @@ func TestAdd(t *testing.T) {
want := MustFromDecimal(tt.want)
got := new(Uint).Add(x, y)

if got.Neq(want) {
t.Errorf("Add(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand All @@ -41,8 +41,8 @@ func TestAddOverflow(t *testing.T) {
{"1", "0", "1", false},
{"1", "1", "2", false},
{"10", "10", "20", false},
{"18446744073709551615", "18446744073709551615", "36893488147419103230", false}, // uint64 overflow, but not Uint256 overflow
{"115792089237316195423570985008687907853269984665640564039457584007913129639935", "1", "0", true}, // 2^256 - 1 + 1, should overflow
{"18446744073709551615", "18446744073709551615", "36893488147419103230", false}, // uint64 overflow, but not Uint256 overflow
{twoPow256Sub1, "1", "0", true},
{"57896044618658097711785492504343953926634992332820282019728792003956564819967", "57896044618658097711785492504343953926634992332820282019728792003956564819968", "115792089237316195423570985008687907853269984665640564039457584007913129639935", false}, // (2^255 - 1) + 2^255, no overflow
{"57896044618658097711785492504343953926634992332820282019728792003956564819967", "57896044618658097711785492504343953926634992332820282019728792003956564819969", "0", true}, // (2^255 - 1) + (2^255 + 1), should overflow
}
Expand All @@ -54,10 +54,8 @@ func TestAddOverflow(t *testing.T) {

got, overflow := new(Uint).AddOverflow(x, y)

if got.Cmp(want) != 0 || overflow != tt.overflow {
t.Errorf("AddOverflow(%s, %s) = (%s, %v), want (%s, %v)",
tt.x, tt.y, got.String(), overflow, tt.want, tt.overflow)
}
uassert.Equal(t, got.String(), want.String())
uassert.Equal(t, overflow, tt.overflow)
}
}

Expand All @@ -75,15 +73,9 @@ func TestSub(t *testing.T) {
y := MustFromDecimal(tc.y)

want := MustFromDecimal(tc.want)

got := new(Uint).Sub(x, y)

if got.Neq(want) {
t.Errorf(
"Sub(%s, %s) = %v, want %v",
tc.x, tc.y, got.String(), want.String(),
)
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand All @@ -109,12 +101,8 @@ func TestSubOverflow(t *testing.T) {

got, overflow := new(Uint).SubOverflow(x, y)

if got.Cmp(want) != 0 || overflow != tc.overflow {
t.Errorf(
"SubOverflow(%s, %s) = (%s, %v), want (%s, %v)",
tc.x, tc.y, got.String(), overflow, tc.want, tc.overflow,
)
}
uassert.Equal(t, got.String(), want.String())
uassert.Equal(t, overflow, tc.overflow)
}
}

Expand All @@ -132,9 +120,7 @@ func TestMul(t *testing.T) {
want := MustFromDecimal(tt.want)
got := new(Uint).Mul(x, y)

if got.Neq(want) {
t.Errorf("Mul(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand Down Expand Up @@ -162,15 +148,8 @@ func TestMulOverflow(t *testing.T) {

gotZ, gotOver := new(Uint).MulOverflow(x, y)

if gotZ.Neq(wantZ) {
t.Errorf(
"MulOverflow(%s, %s) = %s, want %s",
tt.x, tt.y, gotZ.String(), wantZ.String(),
)
}
if gotOver != tt.wantOver {
t.Errorf("MulOverflow(%s, %s) = %v, want %v", tt.x, tt.y, gotOver, tt.wantOver)
}
uassert.Equal(t, gotZ.String(), wantZ.String())
uassert.Equal(t, gotOver, tt.wantOver)
}
}

Expand All @@ -187,13 +166,11 @@ func TestDiv(t *testing.T) {
for _, tt := range tests {
x := MustFromDecimal(tt.x)
y := MustFromDecimal(tt.y)
want := MustFromDecimal(tt.want)

want := MustFromDecimal(tt.want)
got := new(Uint).Div(x, y)

if got.Neq(want) {
t.Errorf("Div(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand All @@ -212,13 +189,11 @@ func TestMod(t *testing.T) {
for _, tt := range tests {
x := MustFromDecimal(tt.x)
y := MustFromDecimal(tt.y)
want := MustFromDecimal(tt.want)

want := MustFromDecimal(tt.want)
got := new(Uint).Mod(x, y)

if got.Neq(want) {
t.Errorf("Mod(%s, %s) = %v, want %v", tt.x, tt.y, got.String(), want.String())
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand All @@ -245,16 +220,11 @@ func TestMulMod(t *testing.T) {
x := MustFromHex(tt.x)
y := MustFromHex(tt.y)
m := MustFromHex(tt.m)
want := MustFromHex(tt.want)

want := MustFromHex(tt.want)
got := new(Uint).MulMod(x, y, m)

if got.Neq(want) {
t.Errorf(
"MulMod(%s, %s, %s) = %s, want %s",
tt.x, tt.y, tt.m, got.String(), want.String(),
)
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand Down Expand Up @@ -284,17 +254,12 @@ func TestDivMod(t *testing.T) {
gotMod := new(Uint)
gotDiv.DivMod(x, y, gotMod)

for i := range gotDiv.arr {
if gotDiv.arr[i] != wantDiv.arr[i] {
t.Errorf("DivMod(%s, %s) got Div %v, want Div %v", tt.x, tt.y, gotDiv, wantDiv)
break
}
}
for i := range gotMod.arr {
if gotMod.arr[i] != wantMod.arr[i] {
t.Errorf("DivMod(%s, %s) got Mod %v, want Mod %v", tt.x, tt.y, gotMod, wantMod)
break
}
uassert.Equal(t, gotMod.arr[i], wantMod.arr[i])
}

for i := range gotDiv.arr {
uassert.Equal(t, gotDiv.arr[i], wantDiv.arr[i])
}
}
}
Expand All @@ -314,12 +279,9 @@ func TestNeg(t *testing.T) {
for _, tt := range tests {
x := MustFromDecimal(tt.x)
want := MustFromDecimal(tt.want)

got := new(Uint).Neg(x)

if got.Neq(want) {
t.Errorf("Neg(%s) = %v, want %v", tt.x, got.String(), want.String())
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand All @@ -339,16 +301,11 @@ func TestExp(t *testing.T) {
for _, tt := range tests {
x := MustFromDecimal(tt.x)
y := MustFromDecimal(tt.y)
want := MustFromDecimal(tt.want)

want := MustFromDecimal(tt.want)
got := new(Uint).Exp(x, y)

if got.Neq(want) {
t.Errorf(
"Exp(%s, %s) = %v, want %v",
tt.x, tt.y, got.String(), want.String(),
)
}
uassert.Equal(t, got.String(), want.String())
}
}

Expand Down Expand Up @@ -378,15 +335,50 @@ func TestExp_LargeExponent(t *testing.T) {
base := MustFromDecimal(tt.base)
exponent := MustFromDecimal(tt.exponent)
expected := MustFromDecimal(tt.expected)

result := new(Uint).Exp(base, exponent)

if result.Neq(expected) {
t.Errorf(
"Test %s failed. Expected %s, got %s",
tt.name, expected.String(), result.String(),
)
}
uassert.Equal(t, result.String(), expected.String())
})
}
}

func TestIsOverflow(t *testing.T) {
tests := []struct {
name string
input *Uint
expected bool
}{
{
name: "Number greater than max value",
input: &Uint{arr: [4]uint64{
^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0),
}},
expected: true,
},
{
name: "Max value",
input: &Uint{arr: [4]uint64{
^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0) >> 1,
}},
expected: false,
},
{
name: "0",
input: &Uint{arr: [4]uint64{0, 0, 0, 0}},
expected: false,
},
{
name: "Only 255th bit set",
input: &Uint{arr: [4]uint64{
0, 0, 0, uint64(1) << 63,
}},
expected: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uassert.Equal(t, tt.input.IsOverflow(), tt.expected)
})
}
}
42 changes: 39 additions & 3 deletions examples/gno.land/p/demo/uint256/bits_table.gno
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by go run make_tables.go. DO NOT EDIT.

package uint256

// ntz8tab: A lookup table for 8-bit values (0-255) that shows
// the number of trailing zeros (zeros from the rightmost/LSB position).
//
// Example) 0x28 (binary 00101000)
//
// Binary: [ 0 0 1 0 1 0 0 0 ]
// ^^^^^^^^
// 3 consecutive zeros
//
// ntz8tab[0x28] = 3
const ntz8tab = "" +
"\x08\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
Expand All @@ -24,6 +31,15 @@ const ntz8tab = "" +
"\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00"

// pop8tab: A lookup table for 8-bit values (0-255) that allows
// quick lookup of the number of set bits (bits set to 1).
//
// Example) 0xB2 (binary 10110010)
//
// Binary: [ 1 0 1 1 0 0 1 0 ]
// Total of 4 set bits
//
// pop8tab[0xB2] = 4
const pop8tab = "" +
"\x00\x01\x01\x02\x01\x02\x02\x03\x01\x02\x02\x03\x02\x03\x03\x04" +
"\x01\x02\x02\x03\x02\x03\x03\x04\x02\x03\x03\x04\x03\x04\x04\x05" +
Expand All @@ -42,6 +58,15 @@ const pop8tab = "" +
"\x03\x04\x04\x05\x04\x05\x05\x06\x04\x05\x05\x06\x05\x06\x06\x07" +
"\x04\x05\x05\x06\x05\x06\x06\x07\x05\x06\x06\x07\x06\x07\x07\x08"

// rev8tab: A lookup table that pre-calculates bit-reversed results
// for 8-bit values (0-255).
//
// Example) 0x16 (binary 00010110)
//
// Binary: [ 0 0 0 1 0 1 1 0 ]
// Reversed: [ 0 1 1 0 1 0 0 0 ] -> 0x68 (104 in decimal)
//
// rev8tab[0x16] = 0x68
const rev8tab = "" +
"\x00\x80\x40\xc0\x20\xa0\x60\xe0\x10\x90\x50\xd0\x30\xb0\x70\xf0" +
"\x08\x88\x48\xc8\x28\xa8\x68\xe8\x18\x98\x58\xd8\x38\xb8\x78\xf8" +
Expand All @@ -60,6 +85,17 @@ const rev8tab = "" +
"\x07\x87\x47\xc7\x27\xa7\x67\xe7\x17\x97\x57\xd7\x37\xb7\x77\xf7" +
"\x0f\x8f\x4f\xcf\x2f\xaf\x6f\xef\x1f\x9f\x5f\xdf\x3f\xbf\x7f\xff"

// len8tab: A lookup table that pre-calculates the "bit length"
// of 8-bit values (0-255).
// (Bit length: position of the most significant bit + 1)
//
// Examples)
//
// 0x00 (binary 00000000) → No MSB → length 0
// 0x01 (binary 00000001) → MSB at rightmost position → length 1
// 0x02 (binary 00000010) ~ 0x03 (00000011) → length 2
// 0x04 (binary 00000100) ~ 0x07 (00000111) → length 3
// ...
const len8tab = "" +
"\x00\x01\x02\x02\x03\x03\x03\x03\x04\x04\x04\x04\x04\x04\x04\x04" +
"\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05" +
Expand Down
Loading

0 comments on commit 40a7837

Please sign in to comment.