Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(counters): smt depth fix and test #1695

Merged
merged 3 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-unwinds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
workflow_dispatch:

jobs:
fixing-unwinds-tests:
unwind-tests:
runs-on: ubuntu-22.04

steps:
Expand Down
11 changes: 7 additions & 4 deletions smt/pkg/db/mdbx.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ const TableAccountValues = "HermezSmtAccountValues"
const TableMetadata = "HermezSmtMetadata"
const TableHashKey = "HermezSmtHashKey"

const MetaLastRoot = "lastRoot"
const MetaDepth = "depth"

var HermezSmtTables = []string{TableSmt, TableStats, TableAccountValues, TableMetadata, TableHashKey}

type EriDb struct {
Expand Down Expand Up @@ -219,7 +222,7 @@ func (m *EriDb) RollbackBatch() {
}

func (m *EriRoDb) GetLastRoot() (*big.Int, error) {
data, err := m.kvTxRo.GetOne(TableStats, []byte("lastRoot"))
data, err := m.kvTxRo.GetOne(TableStats, []byte(MetaLastRoot))
if err != nil {
return big.NewInt(0), err
}
Expand All @@ -233,11 +236,11 @@ func (m *EriRoDb) GetLastRoot() (*big.Int, error) {

func (m *EriDb) SetLastRoot(r *big.Int) error {
v := utils.ConvertBigIntToHex(r)
return m.tx.Put(TableStats, []byte("lastRoot"), []byte(v))
return m.tx.Put(TableStats, []byte(MetaLastRoot), []byte(v))
}

func (m *EriRoDb) GetDepth() (uint8, error) {
data, err := m.kvTxRo.GetOne(TableStats, []byte("depth"))
data, err := m.kvTxRo.GetOne(TableStats, []byte(MetaDepth))
if err != nil {
return 0, err
}
Expand All @@ -250,7 +253,7 @@ func (m *EriRoDb) GetDepth() (uint8, error) {
}

func (m *EriDb) SetDepth(depth uint8) error {
return m.tx.Put(TableStats, []byte("lastRoot"), []byte{depth})
return m.tx.Put(TableStats, []byte(MetaDepth), []byte{depth})
}

func (m *EriRoDb) Get(key utils.NodeKey) (utils.NodeValue12, error) {
Expand Down
21 changes: 12 additions & 9 deletions smt/pkg/smt/smt_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ func (s *SMT) GenerateFromKVBulk(ctx context.Context, logPrefix string, nodeKeys

var buildSmtLoopErr error
var rootNode *SmtNode
var maxDepth int
tempTreeBuildStart := time.Now()
leafValueMap := sync.Map{}
accountValuesReadChan := make(chan *utils.NodeValue8, 1024)
go func() {
defer wg.Done()
defer deletesWorker.Stop()
rootNode, buildSmtLoopErr = runBuildSmtLoop(s, logPrefix, nodeKeys, &leafValueMap, deletesWorker, accountValuesReadChan)
rootNode, maxDepth, buildSmtLoopErr = runBuildSmtLoop(s, logPrefix, nodeKeys, &leafValueMap, deletesWorker, accountValuesReadChan)
}()

// startBuildSmtLoopDbCompanionLoop is blocking operation. It continue only when the last result is saved
Expand Down Expand Up @@ -124,10 +125,14 @@ func (s *SMT) GenerateFromKVBulk(ctx context.Context, logPrefix string, nodeKeys
return [4]uint64{}, err
}

if err := s.updateDepth(maxDepth); err != nil {
return [4]uint64{}, err
}

return finalRoot, nil
}

func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafValueMap *sync.Map, deletesWorker *utils.Worker, accountValuesReadChan <-chan *utils.NodeValue8) (*SmtNode, error) {
func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafValueMap *sync.Map, deletesWorker *utils.Worker, accountValuesReadChan <-chan *utils.NodeValue8) (*SmtNode, int, error) {
totalKeysCount := len(nodeKeys)
insertedKeysCount := uint64(0)
maxReachedLevel := 0
Expand All @@ -148,7 +153,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
keys := k.GetPath()
vPointer := <-accountValuesReadChan
if vPointer == nil {
return nil, fmt.Errorf("the actual error is returned by main DB thread")
return nil, 0, fmt.Errorf("the actual error is returned by main DB thread")
}
v := *vPointer
leafValueMap.Store(k, &v)
Expand Down Expand Up @@ -202,7 +207,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
//sanity check - new leaf should be on the right side
//otherwise something went wrong
if leaf0.rKey[level2] != 0 || keys[level2+level] != 1 {
return nil, fmt.Errorf(
return nil, 0, fmt.Errorf(
"leaf insert error. new leaf should be on the right of the old, oldLeaf: %v, newLeaf: %v",
append(keys[:level+1], leaf0.rKey[level2:]...),
keys,
Expand Down Expand Up @@ -264,7 +269,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
// this is case for 1 leaf inserted to the left of the root node
if len(siblings) == 0 && keys[0] == 0 {
if upperNode.node0 != nil {
return nil, fmt.Errorf("tried to override left node")
return nil, 0, fmt.Errorf("tried to override left node")
}
upperNode.node0 = newNode
} else {
Expand All @@ -273,7 +278,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
//the new leaf should be on the right side
//otherwise something went wrong
if upperNode.node1 != nil || keys[level] != 1 {
return nil, fmt.Errorf(
return nil, 0, fmt.Errorf(
"leaf insert error. new should be on the right of the found node, foundNode: %v, newLeafKey: %v",
upperNode.node1,
keys,
Expand Down Expand Up @@ -318,9 +323,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
progressChan <- uint64(totalKeysCount) + insertedKeysCount
}

s.updateDepth(maxReachedLevel)

return &rootNode, nil
return &rootNode, maxReachedLevel, nil
}

func startBuildSmtLoopDbCompanionLoop(s *SMT, nodeKeys []utils.NodeKey, jobResultsChannel chan utils.JobResult, accountValuesReadChan chan *utils.NodeValue8) error {
Expand Down
18 changes: 15 additions & 3 deletions turbo/jsonrpc/zkevm_counters.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,16 +299,28 @@ func populateCounters(collected *vm.Counters, execResult *core.ExecutionResult,
return resJson, nil
}

func getSmtDepth(hermezDb *hermez_db.HermezDbReader, blockNum uint64, config *tracers.TraceConfig_ZkEvm) (int, error) {
var smtDepth int
type IDepthGetter interface {
GetClosestSmtDepth(blockNum uint64) (depthBlockNum uint64, smtDepth uint64, err error)
}

func getSmtDepth(
hermezDb IDepthGetter,
blockNum uint64,
config *tracers.TraceConfig_ZkEvm,
) (smtDepth int, err error) {
if config != nil && config.SmtDepth != nil {
smtDepth = *config.SmtDepth
} else {
depthBlockNum, smtDepth, err := hermezDb.GetClosestSmtDepth(blockNum)
var depthBlockNum uint64
var smtDepthUint64 uint64

depthBlockNum, smtDepthUint64, err = hermezDb.GetClosestSmtDepth(blockNum)
if err != nil {
return 0, err
}

smtDepth = int(smtDepthUint64)

if depthBlockNum < blockNum {
smtDepth += smtDepth / 10
}
Expand Down
146 changes: 146 additions & 0 deletions turbo/jsonrpc/zkevm_counters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package jsonrpc

import (
"testing"
"errors"
"github.com/ledgerwatch/erigon/eth/tracers"
)

type MockDepthGetter struct {
DepthBlockNum uint64
SMTDepth uint64
Err error
}

func (m *MockDepthGetter) GetClosestSmtDepth(blockNum uint64) (uint64, uint64, error) {
return m.DepthBlockNum, m.SMTDepth, m.Err
}

func intPtr(i int) *int {
return &i
}

func TestGetSmtDepth(t *testing.T) {
testCases := map[string]struct {
blockNum uint64
config *tracers.TraceConfig_ZkEvm
mockSetup func(m *MockDepthGetter)
expectedDepth int
expectedErr error
}{
"Config provided with SmtDepth": {
blockNum: 100,
config: &tracers.TraceConfig_ZkEvm{
SmtDepth: intPtr(128),
},
mockSetup: func(m *MockDepthGetter) {
// No DB call expected.
},
expectedDepth: 128,
expectedErr: nil,
},
"Config is nil, GetClosestSmtDepth returns depthBlockNum < blockNum": {
blockNum: 100,
config: nil,
mockSetup: func(m *MockDepthGetter) {
m.DepthBlockNum = 90
m.SMTDepth = 100
m.Err = nil
},
expectedDepth: 110, // 100 + 100/10
expectedErr: nil,
},
"Config is nil, GetClosestSmtDepth returns depthBlockNum >= blockNum": {
blockNum: 100,
config: nil,
mockSetup: func(m *MockDepthGetter) {
m.DepthBlockNum = 100
m.SMTDepth = 100
m.Err = nil
},
expectedDepth: 100,
expectedErr: nil,
},
"Config is nil, smtDepth after adjustment exceeds 256": {
blockNum: 100,
config: nil,
mockSetup: func(m *MockDepthGetter) {
m.DepthBlockNum = 90
m.SMTDepth = 250
m.Err = nil
},
expectedDepth: 256, // 250 + 25 = 275 -> capped to 256
expectedErr: nil,
},
"Config is nil, smtDepth is 0": {
blockNum: 100,
config: nil,
mockSetup: func(m *MockDepthGetter) {
m.DepthBlockNum = 90
m.SMTDepth = 0
m.Err = nil
},
expectedDepth: 256, // 0 is invalid, set to 256
expectedErr: nil,
},
"Config is nil, GetClosestSmtDepth returns error": {
blockNum: 100,
config: nil,
mockSetup: func(m *MockDepthGetter) {
m.DepthBlockNum = 0
m.SMTDepth = 0
m.Err = errors.New("database error")
},
expectedDepth: 0,
expectedErr: errors.New("database error"),
},
"Config provided with SmtDepth exceeding 256": {
blockNum: 100,
config: &tracers.TraceConfig_ZkEvm{
SmtDepth: intPtr(300),
},
mockSetup: func(m *MockDepthGetter) {
// No DB call expected.
},
expectedDepth: 300, // As per the function logic, returned as-is.
expectedErr: nil,
},
"Config provided with SmtDepth set to 0": {
blockNum: 100,
config: &tracers.TraceConfig_ZkEvm{
SmtDepth: intPtr(0),
},
mockSetup: func(m *MockDepthGetter) {
// No DB call expected.
},
expectedDepth: 0, // As per the function logic, returned as-is.
expectedErr: nil,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
mock := &MockDepthGetter{}
tc.mockSetup(mock)

actualDepth, actualErr := getSmtDepth(mock, tc.blockNum, tc.config)

if tc.expectedErr != nil {
if actualErr == nil {
t.Fatalf("expected error '%v', but got nil", tc.expectedErr)
}
if actualErr.Error() != tc.expectedErr.Error() {
t.Fatalf("expected error '%v', but got '%v'", tc.expectedErr, actualErr)
}
} else {
if actualErr != nil {
t.Fatalf("expected no error, but got '%v'", actualErr)
}
}

if actualDepth != tc.expectedDepth {
t.Errorf("expected smtDepth %d, but got %d", tc.expectedDepth, actualDepth)
}
})
}
}
10 changes: 5 additions & 5 deletions zk/smt/unwind_smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ func UnwindZkSMT(ctx context.Context, logPrefix string, from, to uint64, tx kv.R
eridb.OpenBatch(quit)
}

changesGetter := NewChangesGetter(tx)
if err := changesGetter.openChangesGetter(from); err != nil {
cg := NewChangesGetter(tx)
if err := cg.openChangesGetter(from); err != nil {
return trie.EmptyRoot, fmt.Errorf("OpenChangesGetter: %w", err)
}
defer changesGetter.closeChangesGetter()
defer cg.closeChangesGetter()

total := uint64(math.Abs(float64(from) - float64(to) + 1))
progressChan, stopPrinter := zk.ProgressPrinter(fmt.Sprintf("[%s] Progress unwinding", logPrefix), total, quiet)
Expand All @@ -58,7 +58,7 @@ func UnwindZkSMT(ctx context.Context, logPrefix string, from, to uint64, tx kv.R
default:
}

if err := changesGetter.getChangesForBlock(i); err != nil {
if err := cg.getChangesForBlock(i); err != nil {
return trie.EmptyRoot, fmt.Errorf("getChangesForBlock: %w", err)
}

Expand All @@ -67,7 +67,7 @@ func UnwindZkSMT(ctx context.Context, logPrefix string, from, to uint64, tx kv.R

stopPrinter()

if _, _, err := dbSmt.SetStorage(ctx, logPrefix, changesGetter.accChanges, changesGetter.codeChanges, changesGetter.storageChanges); err != nil {
if _, _, err := dbSmt.SetStorage(ctx, logPrefix, cg.accChanges, cg.codeChanges, cg.storageChanges); err != nil {
return trie.EmptyRoot, err
}

Expand Down
Loading
Loading