Skip to content

Commit

Permalink
fix(devnet-sdk): ensure balances are comparable
Browse files Browse the repository at this point in the history
In error cases, make sure we can still compare resulting balances
without crashing.

In particular, this makes interop_smoke_test skip properly if it can't
find a funded wallet to use.
  • Loading branch information
sigma committed Feb 10, 2025
1 parent 6b849f6 commit df3cfd2
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 43 deletions.
5 changes: 3 additions & 2 deletions devnet-sdk/system/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package system
import (
"context"
"fmt"
"math/big"
"strings"

"github.com/ethereum-optimism/optimism/devnet-sdk/types"
Expand Down Expand Up @@ -52,12 +53,12 @@ func (w *wallet) SendETH(to types.Address, amount types.Balance) types.WriteInvo
func (w *wallet) Balance() types.Balance {
client, err := w.chain.getClient()
if err != nil {
return types.Balance{}
return types.NewBalance(new(big.Int))
}

balance, err := client.BalanceAt(context.Background(), w.address, nil)
if err != nil {
return types.Balance{}
return types.NewBalance(new(big.Int))
}

return types.NewBalance(balance)
Expand Down
92 changes: 92 additions & 0 deletions devnet-sdk/system/wallet_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package system

import (
"context"
"math/big"
"testing"

"github.com/ethereum-optimism/optimism/devnet-sdk/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// testWallet is a minimal wallet implementation for testing balance functionality
type testWallet struct {
privateKey types.Key
address types.Address
chain *mockChainForBalance // Use concrete type to access mock client directly
}

func (w *testWallet) Balance() types.Balance {
// Use the mock client directly instead of going through getClient()
balance, err := w.chain.client.BalanceAt(context.Background(), w.address, nil)
if err != nil {
return types.NewBalance(new(big.Int))
}

return types.NewBalance(balance)
}

// mockEthClient implements a mock ethereum client for testing
type mockEthClient struct {
mock.Mock
}

func (m *mockEthClient) BalanceAt(ctx context.Context, account types.Address, blockNumber *big.Int) (*big.Int, error) {
args := m.Called(ctx, account, blockNumber)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*big.Int), args.Error(1)
}

// mockChainForBalance implements just enough of the chain interface for balance testing
type mockChainForBalance struct {
mock.Mock
client *mockEthClient
}

func TestWalletBalance(t *testing.T) {
tests := []struct {
name string
setupMock func(*mockChainForBalance)
expectedValue *big.Int
}{
{
name: "successful balance fetch",
setupMock: func(m *mockChainForBalance) {
balance := big.NewInt(1000000000000000000) // 1 ETH
m.client.On("BalanceAt", mock.Anything, mock.Anything, mock.Anything).Return(balance, nil)
},
expectedValue: big.NewInt(1000000000000000000),
},
{
name: "balance fetch error returns zero",
setupMock: func(m *mockChainForBalance) {
m.client.On("BalanceAt", mock.Anything, mock.Anything, mock.Anything).Return(nil, assert.AnError)
},
expectedValue: new(big.Int),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockChain := &mockChainForBalance{
client: new(mockEthClient),
}
tt.setupMock(mockChain)

w := &testWallet{
privateKey: "test-key",
address: types.Address{},
chain: mockChain,
}

balance := w.Balance()
assert.Equal(t, 0, balance.Int.Cmp(tt.expectedValue))

mockChain.AssertExpectations(t)
mockChain.client.AssertExpectations(t)
})
}
}
20 changes: 19 additions & 1 deletion devnet-sdk/types/balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,34 @@ func (b Balance) Mul(f float64) Balance {

// GreaterThan returns true if this balance is greater than other
func (b Balance) GreaterThan(other Balance) bool {
if b.Int == nil {
return false
}
if other.Int == nil {
return true
}
return b.Int.Cmp(other.Int) > 0
}

// LessThan returns true if this balance is less than other
func (b Balance) LessThan(other Balance) bool {
if b.Int == nil {
return other.Int != nil
}
if other.Int == nil {
return false
}
return b.Int.Cmp(other.Int) < 0
}

// Equal returns true if this balance equals other
func (b Balance) Equal(other Balance) bool {
if b.Int == nil {
return other.Int == nil
}
if other.Int == nil {
return false
}
return b.Int.Cmp(other.Int) == 0
}

Expand All @@ -59,7 +77,7 @@ func (b Balance) LogValue() slog.Value {

// 1 ETH = 1e18 Wei
if eth.Cmp(new(big.Float).SetFloat64(0.001)) >= 0 {
str := eth.Text('g', 3)
str := eth.Text('f', 0)
return slog.StringValue(fmt.Sprintf("%s ETH", str))
}

Expand Down
201 changes: 161 additions & 40 deletions devnet-sdk/types/balance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package types
import (
"math/big"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewBalance(t *testing.T) {
Expand Down Expand Up @@ -87,61 +89,180 @@ func TestBalance_Mul(t *testing.T) {
}
}

func TestBalance_Comparisons(t *testing.T) {
func TestBalanceComparisons(t *testing.T) {
tests := []struct {
a, b int64
gt, lt, eq bool
name string
balance1 Balance
balance2 Balance
greater bool
less bool
equal bool
}{
{100, 200, false, true, false},
{200, 100, true, false, false},
{100, 100, false, false, true},
{0, 100, false, true, false},
{
name: "both nil",
balance1: Balance{},
balance2: Balance{},
greater: false,
less: false,
equal: true,
},
{
name: "first nil",
balance1: Balance{},
balance2: NewBalance(big.NewInt(100)),
greater: false,
less: true,
equal: false,
},
{
name: "second nil",
balance1: NewBalance(big.NewInt(100)),
balance2: Balance{},
greater: true,
less: false,
equal: false,
},
{
name: "first greater",
balance1: NewBalance(big.NewInt(200)),
balance2: NewBalance(big.NewInt(100)),
greater: true,
less: false,
equal: false,
},
{
name: "second greater",
balance1: NewBalance(big.NewInt(100)),
balance2: NewBalance(big.NewInt(200)),
greater: false,
less: true,
equal: false,
},
{
name: "equal values",
balance1: NewBalance(big.NewInt(100)),
balance2: NewBalance(big.NewInt(100)),
greater: false,
less: false,
equal: true,
},
{
name: "zero values",
balance1: NewBalance(new(big.Int)),
balance2: NewBalance(new(big.Int)),
greater: false,
less: false,
equal: true,
},
}

for _, tt := range tests {
a := NewBalance(big.NewInt(tt.a))
b := NewBalance(big.NewInt(tt.b))
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.greater, tt.balance1.GreaterThan(tt.balance2), "GreaterThan check failed")
assert.Equal(t, tt.less, tt.balance1.LessThan(tt.balance2), "LessThan check failed")
assert.Equal(t, tt.equal, tt.balance1.Equal(tt.balance2), "Equal check failed")
})
}
}

if got := a.GreaterThan(b); got != tt.gt {
t.Errorf("GreaterThan(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.gt)
}
func TestBalanceArithmetic(t *testing.T) {
tests := []struct {
name string
balance1 Balance
balance2 Balance
add *big.Int
sub *big.Int
mul float64
mulRes *big.Int
}{
{
name: "basic arithmetic",
balance1: NewBalance(big.NewInt(100)),
balance2: NewBalance(big.NewInt(50)),
add: big.NewInt(150),
sub: big.NewInt(50),
mul: 2.5,
mulRes: big.NewInt(250),
},
{
name: "zero values",
balance1: NewBalance(new(big.Int)),
balance2: NewBalance(new(big.Int)),
add: new(big.Int),
sub: new(big.Int),
mul: 1.0,
mulRes: new(big.Int),
},
{
name: "large numbers",
balance1: NewBalance(new(big.Int).Mul(big.NewInt(1e18), big.NewInt(100))), // 100 ETH
balance2: NewBalance(new(big.Int).Mul(big.NewInt(1e18), big.NewInt(50))), // 50 ETH
add: new(big.Int).Mul(big.NewInt(1e18), big.NewInt(150)), // 150 ETH
sub: new(big.Int).Mul(big.NewInt(1e18), big.NewInt(50)), // 50 ETH
mul: 0.5,
mulRes: new(big.Int).Mul(big.NewInt(1e18), big.NewInt(50)), // 50 ETH
},
}

if got := a.LessThan(b); got != tt.lt {
t.Errorf("LessThan(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.lt)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test Add
sum := tt.balance1.Add(tt.balance2)
assert.Equal(t, 0, sum.Int.Cmp(tt.add), "Add result mismatch")

if got := a.Equal(b); got != tt.eq {
t.Errorf("Equal(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.eq)
}
// Test Sub
diff := tt.balance1.Sub(tt.balance2)
assert.Equal(t, 0, diff.Int.Cmp(tt.sub), "Sub result mismatch")

// Test Mul
product := tt.balance1.Mul(tt.mul)
assert.Equal(t, 0, product.Int.Cmp(tt.mulRes), "Mul result mismatch")
})
}
}

func TestBalance_LogValue(t *testing.T) {
func TestBalanceLogValue(t *testing.T) {
tests := []struct {
wei string // Using string to handle large numbers
want string
name string
balance Balance
expected string
}{
{"2000000000000000000", "2 ETH"}, // 2 ETH
{"1000000000", "1 Gwei"}, // 1 Gwei
{"100", "100 Wei"}, // 100 Wei
{"1500000000000000000", "1.5 ETH"}, // 1.5 ETH
{"0", "0 Wei"}, // 0
{
name: "nil balance",
balance: Balance{},
expected: "0 ETH",
},
{
name: "zero balance",
balance: NewBalance(new(big.Int)),
expected: "0 Wei",
},
{
name: "small wei amount",
balance: NewBalance(big.NewInt(100)),
expected: "100 Wei",
},
{
name: "gwei amount",
balance: NewBalance(new(big.Int).Mul(big.NewInt(1), big.NewInt(1e9))),
expected: "1 Gwei",
},
{
name: "eth amount",
balance: NewBalance(new(big.Int).Mul(big.NewInt(1), big.NewInt(1e18))),
expected: "1 ETH",
},
{
name: "large eth amount",
balance: NewBalance(new(big.Int).Mul(big.NewInt(1000), big.NewInt(1e18))),
expected: "1000 ETH",
},
}

for _, tt := range tests {
i := new(big.Int)
i.SetString(tt.wei, 10)
b := NewBalance(i)
got := b.LogValue().String()
if got != tt.want {
t.Errorf("LogValue() for %v Wei = %v, want %v", tt.wei, got, tt.want)
}
}

// Test nil case
var nilBalance Balance
got := nilBalance.LogValue().String()
if got != "0 ETH" {
t.Errorf("LogValue() for nil balance = %v, want '0 ETH'", got)
t.Run(tt.name, func(t *testing.T) {
logValue := tt.balance.LogValue()
assert.Equal(t, tt.expected, logValue.String())
})
}
}

0 comments on commit df3cfd2

Please sign in to comment.