Skip to content

Commit

Permalink
fix(devnet-sdk): ensure balances are comparable (#14267)
Browse files Browse the repository at this point in the history
In error cases, make sure we can still compare resulting balances
without crashing.

In particular, this makes interop_smoke_test skip properly if it can't
find a funded wallet to use.
  • Loading branch information
sigma authored Feb 10, 2025
1 parent 6b849f6 commit d2ca378
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 43 deletions.
5 changes: 3 additions & 2 deletions devnet-sdk/system/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package system
import (
"context"
"fmt"
"math/big"
"strings"

"github.com/ethereum-optimism/optimism/devnet-sdk/types"
Expand Down Expand Up @@ -52,12 +53,12 @@ func (w *wallet) SendETH(to types.Address, amount types.Balance) types.WriteInvo
func (w *wallet) Balance() types.Balance {
client, err := w.chain.getClient()
if err != nil {
return types.Balance{}
return types.NewBalance(new(big.Int))
}

balance, err := client.BalanceAt(context.Background(), w.address, nil)
if err != nil {
return types.Balance{}
return types.NewBalance(new(big.Int))
}

return types.NewBalance(balance)
Expand Down
92 changes: 92 additions & 0 deletions devnet-sdk/system/wallet_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package system

import (
"context"
"math/big"
"testing"

"github.com/ethereum-optimism/optimism/devnet-sdk/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// testWallet is a minimal wallet implementation for testing balance functionality
type testWallet struct {
privateKey types.Key
address types.Address
chain *mockChainForBalance // Use concrete type to access mock client directly
}

func (w *testWallet) Balance() types.Balance {
// Use the mock client directly instead of going through getClient()
balance, err := w.chain.client.BalanceAt(context.Background(), w.address, nil)
if err != nil {
return types.NewBalance(new(big.Int))
}

return types.NewBalance(balance)
}

// mockEthClient implements a mock ethereum client for testing
type mockEthClient struct {
mock.Mock
}

func (m *mockEthClient) BalanceAt(ctx context.Context, account types.Address, blockNumber *big.Int) (*big.Int, error) {
args := m.Called(ctx, account, blockNumber)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*big.Int), args.Error(1)
}

// mockChainForBalance implements just enough of the chain interface for balance testing
type mockChainForBalance struct {
mock.Mock
client *mockEthClient
}

func TestWalletBalance(t *testing.T) {
tests := []struct {
name string
setupMock func(*mockChainForBalance)
expectedValue *big.Int
}{
{
name: "successful balance fetch",
setupMock: func(m *mockChainForBalance) {
balance := big.NewInt(1000000000000000000) // 1 ETH
m.client.On("BalanceAt", mock.Anything, mock.Anything, mock.Anything).Return(balance, nil)
},
expectedValue: big.NewInt(1000000000000000000),
},
{
name: "balance fetch error returns zero",
setupMock: func(m *mockChainForBalance) {
m.client.On("BalanceAt", mock.Anything, mock.Anything, mock.Anything).Return(nil, assert.AnError)
},
expectedValue: new(big.Int),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockChain := &mockChainForBalance{
client: new(mockEthClient),
}
tt.setupMock(mockChain)

w := &testWallet{
privateKey: "test-key",
address: types.Address{},
chain: mockChain,
}

balance := w.Balance()
assert.Equal(t, 0, balance.Int.Cmp(tt.expectedValue))

mockChain.AssertExpectations(t)
mockChain.client.AssertExpectations(t)
})
}
}
20 changes: 19 additions & 1 deletion devnet-sdk/types/balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,34 @@ func (b Balance) Mul(f float64) Balance {

// GreaterThan returns true if this balance is greater than other
func (b Balance) GreaterThan(other Balance) bool {
if b.Int == nil {
return false
}
if other.Int == nil {
return true
}
return b.Int.Cmp(other.Int) > 0
}

// LessThan returns true if this balance is less than other
func (b Balance) LessThan(other Balance) bool {
if b.Int == nil {
return other.Int != nil
}
if other.Int == nil {
return false
}
return b.Int.Cmp(other.Int) < 0
}

// Equal returns true if this balance equals other
func (b Balance) Equal(other Balance) bool {
if b.Int == nil {
return other.Int == nil
}
if other.Int == nil {
return false
}
return b.Int.Cmp(other.Int) == 0
}

Expand All @@ -59,7 +77,7 @@ func (b Balance) LogValue() slog.Value {

// 1 ETH = 1e18 Wei
if eth.Cmp(new(big.Float).SetFloat64(0.001)) >= 0 {
str := eth.Text('g', 3)
str := eth.Text('f', 0)
return slog.StringValue(fmt.Sprintf("%s ETH", str))
}

Expand Down
201 changes: 161 additions & 40 deletions devnet-sdk/types/balance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package types
import (
"math/big"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewBalance(t *testing.T) {
Expand Down Expand Up @@ -87,61 +89,180 @@ func TestBalance_Mul(t *testing.T) {
}
}

func TestBalance_Comparisons(t *testing.T) {
func TestBalanceComparisons(t *testing.T) {
tests := []struct {
a, b int64
gt, lt, eq bool
name string
balance1 Balance
balance2 Balance
greater bool
less bool
equal bool
}{
{100, 200, false, true, false},
{200, 100, true, false, false},
{100, 100, false, false, true},
{0, 100, false, true, false},
{
name: "both nil",
balance1: Balance{},
balance2: Balance{},
greater: false,
less: false,
equal: true,
},
{
name: "first nil",
balance1: Balance{},
balance2: NewBalance(big.NewInt(100)),
greater: false,
less: true,
equal: false,
},
{
name: "second nil",
balance1: NewBalance(big.NewInt(100)),
balance2: Balance{},
greater: true,
less: false,
equal: false,
},
{
name: "first greater",
balance1: NewBalance(big.NewInt(200)),
balance2: NewBalance(big.NewInt(100)),
greater: true,
less: false,
equal: false,
},
{
name: "second greater",
balance1: NewBalance(big.NewInt(100)),
balance2: NewBalance(big.NewInt(200)),
greater: false,
less: true,
equal: false,
},
{
name: "equal values",
balance1: NewBalance(big.NewInt(100)),
balance2: NewBalance(big.NewInt(100)),
greater: false,
less: false,
equal: true,
},
{
name: "zero values",
balance1: NewBalance(new(big.Int)),
balance2: NewBalance(new(big.Int)),
greater: false,
less: false,
equal: true,
},
}

for _, tt := range tests {
a := NewBalance(big.NewInt(tt.a))
b := NewBalance(big.NewInt(tt.b))
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.greater, tt.balance1.GreaterThan(tt.balance2), "GreaterThan check failed")
assert.Equal(t, tt.less, tt.balance1.LessThan(tt.balance2), "LessThan check failed")
assert.Equal(t, tt.equal, tt.balance1.Equal(tt.balance2), "Equal check failed")
})
}
}

if got := a.GreaterThan(b); got != tt.gt {
t.Errorf("GreaterThan(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.gt)
}
func TestBalanceArithmetic(t *testing.T) {
tests := []struct {
name string
balance1 Balance
balance2 Balance
add *big.Int
sub *big.Int
mul float64
mulRes *big.Int
}{
{
name: "basic arithmetic",
balance1: NewBalance(big.NewInt(100)),
balance2: NewBalance(big.NewInt(50)),
add: big.NewInt(150),
sub: big.NewInt(50),
mul: 2.5,
mulRes: big.NewInt(250),
},
{
name: "zero values",
balance1: NewBalance(new(big.Int)),
balance2: NewBalance(new(big.Int)),
add: new(big.Int),
sub: new(big.Int),
mul: 1.0,
mulRes: new(big.Int),
},
{
name: "large numbers",
balance1: NewBalance(new(big.Int).Mul(big.NewInt(1e18), big.NewInt(100))), // 100 ETH
balance2: NewBalance(new(big.Int).Mul(big.NewInt(1e18), big.NewInt(50))), // 50 ETH
add: new(big.Int).Mul(big.NewInt(1e18), big.NewInt(150)), // 150 ETH
sub: new(big.Int).Mul(big.NewInt(1e18), big.NewInt(50)), // 50 ETH
mul: 0.5,
mulRes: new(big.Int).Mul(big.NewInt(1e18), big.NewInt(50)), // 50 ETH
},
}

if got := a.LessThan(b); got != tt.lt {
t.Errorf("LessThan(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.lt)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test Add
sum := tt.balance1.Add(tt.balance2)
assert.Equal(t, 0, sum.Int.Cmp(tt.add), "Add result mismatch")

if got := a.Equal(b); got != tt.eq {
t.Errorf("Equal(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.eq)
}
// Test Sub
diff := tt.balance1.Sub(tt.balance2)
assert.Equal(t, 0, diff.Int.Cmp(tt.sub), "Sub result mismatch")

// Test Mul
product := tt.balance1.Mul(tt.mul)
assert.Equal(t, 0, product.Int.Cmp(tt.mulRes), "Mul result mismatch")
})
}
}

func TestBalance_LogValue(t *testing.T) {
func TestBalanceLogValue(t *testing.T) {
tests := []struct {
wei string // Using string to handle large numbers
want string
name string
balance Balance
expected string
}{
{"2000000000000000000", "2 ETH"}, // 2 ETH
{"1000000000", "1 Gwei"}, // 1 Gwei
{"100", "100 Wei"}, // 100 Wei
{"1500000000000000000", "1.5 ETH"}, // 1.5 ETH
{"0", "0 Wei"}, // 0
{
name: "nil balance",
balance: Balance{},
expected: "0 ETH",
},
{
name: "zero balance",
balance: NewBalance(new(big.Int)),
expected: "0 Wei",
},
{
name: "small wei amount",
balance: NewBalance(big.NewInt(100)),
expected: "100 Wei",
},
{
name: "gwei amount",
balance: NewBalance(new(big.Int).Mul(big.NewInt(1), big.NewInt(1e9))),
expected: "1 Gwei",
},
{
name: "eth amount",
balance: NewBalance(new(big.Int).Mul(big.NewInt(1), big.NewInt(1e18))),
expected: "1 ETH",
},
{
name: "large eth amount",
balance: NewBalance(new(big.Int).Mul(big.NewInt(1000), big.NewInt(1e18))),
expected: "1000 ETH",
},
}

for _, tt := range tests {
i := new(big.Int)
i.SetString(tt.wei, 10)
b := NewBalance(i)
got := b.LogValue().String()
if got != tt.want {
t.Errorf("LogValue() for %v Wei = %v, want %v", tt.wei, got, tt.want)
}
}

// Test nil case
var nilBalance Balance
got := nilBalance.LogValue().String()
if got != "0 ETH" {
t.Errorf("LogValue() for nil balance = %v, want '0 ETH'", got)
t.Run(tt.name, func(t *testing.T) {
logValue := tt.balance.LogValue()
assert.Equal(t, tt.expected, logValue.String())
})
}
}

0 comments on commit d2ca378

Please sign in to comment.