diff --git a/sequencer/dbmanager.go b/sequencer/dbmanager.go index f8a85f6e8a..6f132d034a 100644 --- a/sequencer/dbmanager.go +++ b/sequencer/dbmanager.go @@ -201,6 +201,14 @@ func (d *dbManager) sendDataToStreamer() { } for _, l2Transaction := range l2Transactions { + // Populate intermediate state root + position := state.GetSystemSCPosition(blockStart.L2BlockNumber) + imStateRoot, err := d.GetStorageAt(context.Background(), common.HexToAddress(state.SystemSC), big.NewInt(0).SetBytes(position), l2Block.StateRoot) + if err != nil { + log.Errorf("failed to get storage at for l2block %v: %v", l2Block.L2BlockNumber, err) + } + l2Transaction.StateRoot = common.BigToHash(imStateRoot) + _, err = d.streamServer.AddStreamEntry(state.EntryTypeL2Tx, l2Transaction.Encode()) if err != nil { log.Errorf("failed to add l2tx stream entry for l2block %v: %v", l2Block.L2BlockNumber, err) @@ -726,3 +734,8 @@ func (d *dbManager) GetForcedBatch(ctx context.Context, forcedBatchNumber uint64 func (d *dbManager) GetForkIDByBatchNumber(batchNumber uint64) uint64 { return d.state.GetForkIDByBatchNumber(batchNumber) } + +// GetStorageAt returns the storage at a given address and position +func (d *dbManager) GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root common.Hash) (*big.Int, error) { + return d.state.GetStorageAt(ctx, address, position, root) +} diff --git a/sequencer/interfaces.go b/sequencer/interfaces.go index 087fbc7464..19e185dc7a 100644 --- a/sequencer/interfaces.go +++ b/sequencer/interfaces.go @@ -86,6 +86,7 @@ type stateInterface interface { GetDSBatches(ctx context.Context, firstBatchNumber, lastBatchNumber uint64, readWIPBatch bool, dbTx pgx.Tx) ([]*state.DSBatch, error) GetDSL2Blocks(ctx context.Context, firstBatchNumber, lastBatchNumber uint64, dbTx pgx.Tx) ([]*state.DSL2Block, error) GetDSL2Transactions(ctx context.Context, firstL2Block, lastL2Block uint64, dbTx pgx.Tx) ([]*state.DSL2Transaction, error) + GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root common.Hash) (*big.Int, error) } type workerInterface interface { @@ -137,4 +138,5 @@ type dbManagerInterface interface { StoreProcessedTxAndDeleteFromPool(ctx context.Context, tx transactionToStore) error GetForcedBatch(ctx context.Context, forcedBatchNumber uint64, dbTx pgx.Tx) (*state.ForcedBatch, error) GetForkIDByBatchNumber(batchNumber uint64) uint64 + GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root common.Hash) (*big.Int, error) } diff --git a/sequencer/mock_db_manager.go b/sequencer/mock_db_manager.go index c969f4c90c..f8d1a815ec 100644 --- a/sequencer/mock_db_manager.go +++ b/sequencer/mock_db_manager.go @@ -545,6 +545,32 @@ func (_m *DbManagerMock) GetLatestVirtualBatchTimestamp(ctx context.Context, dbT return r0, r1 } +// GetStorageAt provides a mock function with given fields: ctx, address, position, root +func (_m *DbManagerMock) GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root common.Hash) (*big.Int, error) { + ret := _m.Called(ctx, address, position, root) + + var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int, common.Hash) (*big.Int, error)); ok { + return rf(ctx, address, position, root) + } + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int, common.Hash) *big.Int); ok { + r0 = rf(ctx, address, position, root) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*big.Int) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, common.Address, *big.Int, common.Hash) error); ok { + r1 = rf(ctx, address, position, root) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetStoredFlushID provides a mock function with given fields: ctx func (_m *DbManagerMock) GetStoredFlushID(ctx context.Context) (uint64, string, error) { ret := _m.Called(ctx) diff --git a/sequencer/mock_state.go b/sequencer/mock_state.go index 7d8717866e..271975328a 100644 --- a/sequencer/mock_state.go +++ b/sequencer/mock_state.go @@ -746,6 +746,32 @@ func (_m *StateMock) GetNonceByStateRoot(ctx context.Context, address common.Add return r0, r1 } +// GetStorageAt provides a mock function with given fields: ctx, address, position, root +func (_m *StateMock) GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root common.Hash) (*big.Int, error) { + ret := _m.Called(ctx, address, position, root) + + var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int, common.Hash) (*big.Int, error)); ok { + return rf(ctx, address, position, root) + } + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int, common.Hash) *big.Int); ok { + r0 = rf(ctx, address, position, root) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*big.Int) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, common.Address, *big.Int, common.Hash) error); ok { + r1 = rf(ctx, address, position, root) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetStoredFlushID provides a mock function with given fields: ctx func (_m *StateMock) GetStoredFlushID(ctx context.Context) (uint64, string, error) { ret := _m.Called(ctx) diff --git a/sequencer/sequencer.go b/sequencer/sequencer.go index f572d9818b..2494f100fc 100644 --- a/sequencer/sequencer.go +++ b/sequencer/sequencer.go @@ -140,7 +140,7 @@ func (s *Sequencer) Start(ctx context.Context) { } func (s *Sequencer) updateDataStreamerFile(ctx context.Context, streamServer *datastreamer.StreamServer) { - err := state.GenerateDataStreamerFile(ctx, streamServer, s.state, true) + err := state.GenerateDataStreamerFile(ctx, streamServer, s.state, true, nil) if err != nil { log.Fatalf("failed to generate data streamer file, err: %v", err) } diff --git a/state/datastream.go b/state/datastream.go index e71251d2ee..ddc7fc13be 100644 --- a/state/datastream.go +++ b/state/datastream.go @@ -3,10 +3,13 @@ package state import ( "context" "encoding/binary" + "math/big" "github.com/0xPolygonHermez/zkevm-data-streamer/datastreamer" "github.com/0xPolygonHermez/zkevm-node/log" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/iden3/go-iden3-crypto/keccak256" "github.com/jackc/pgx/v4" ) @@ -25,6 +28,10 @@ const ( EntryTypeUpdateGER datastreamer.EntryType = 4 // BookMarkTypeL2Block represents a L2 block bookmark BookMarkTypeL2Block byte = 0 + // SystemSC is the system smart contract address + SystemSC = "0x000000000000000000000000000000005ca1ab1e" + // posConstant is the constant used to compute the position of the intermediate state root + posConstant = 1 ) // DSBatch represents a data stream batch @@ -92,10 +99,11 @@ func (b DSL2BlockStart) Decode(data []byte) DSL2BlockStart { // DSL2Transaction represents a data stream L2 transaction type DSL2Transaction struct { - L2BlockNumber uint64 // Not included in the encoded data - EffectiveGasPricePercentage uint8 // 1 byte - IsValid uint8 // 1 byte - EncodedLength uint32 // 4 bytes + L2BlockNumber uint64 // Not included in the encoded data + EffectiveGasPricePercentage uint8 // 1 byte + IsValid uint8 // 1 byte + StateRoot common.Hash // 32 bytes + EncodedLength uint32 // 4 bytes Encoded []byte } @@ -104,6 +112,7 @@ func (l DSL2Transaction) Encode() []byte { bytes := make([]byte, 0) bytes = append(bytes, byte(l.EffectiveGasPricePercentage)) bytes = append(bytes, byte(l.IsValid)) + bytes = append(bytes, l.StateRoot[:]...) bytes = binary.LittleEndian.AppendUint32(bytes, l.EncodedLength) bytes = append(bytes, l.Encoded...) return bytes @@ -113,8 +122,9 @@ func (l DSL2Transaction) Encode() []byte { func (l DSL2Transaction) Decode(data []byte) DSL2Transaction { l.EffectiveGasPricePercentage = uint8(data[0]) l.IsValid = uint8(data[1]) - l.EncodedLength = binary.LittleEndian.Uint32(data[2:6]) - l.Encoded = data[6:] + l.StateRoot = common.BytesToHash(data[2:34]) + l.EncodedLength = binary.LittleEndian.Uint32(data[34:38]) + l.Encoded = data[38:] return l } @@ -202,10 +212,12 @@ type DSState interface { GetDSBatches(ctx context.Context, firstBatchNumber, lastBatchNumber uint64, readWIPBatch bool, dbTx pgx.Tx) ([]*DSBatch, error) GetDSL2Blocks(ctx context.Context, firstBatchNumber, lastBatchNumber uint64, dbTx pgx.Tx) ([]*DSL2Block, error) GetDSL2Transactions(ctx context.Context, firstL2Block, lastL2Block uint64, dbTx pgx.Tx) ([]*DSL2Transaction, error) + GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root common.Hash) (*big.Int, error) + GetLastL2BlockHeader(ctx context.Context, dbTx pgx.Tx) (*types.Header, error) } // GenerateDataStreamerFile generates or resumes a data stream file -func GenerateDataStreamerFile(ctx context.Context, streamServer *datastreamer.StreamServer, stateDB DSState, readWIPBatch bool) error { +func GenerateDataStreamerFile(ctx context.Context, streamServer *datastreamer.StreamServer, stateDB DSState, readWIPBatch bool, imStateRoots *map[uint64][]byte) error { header := streamServer.GetHeader() var currentBatchNumber uint64 = 0 @@ -345,8 +357,6 @@ func GenerateDataStreamerFile(ctx context.Context, streamServer *datastreamer.St // Gererate full batches fullBatches := computeFullBatches(batches, l2Blocks, l2Txs) - log.Debugf("Full batches: %+v", fullBatches) - currentBatchNumber += limit for _, batch := range fullBatches { @@ -418,6 +428,18 @@ func GenerateDataStreamerFile(ctx context.Context, streamServer *datastreamer.St } for _, tx := range l2block.Txs { + // Populate intermediate state root + if imStateRoots == nil || (*imStateRoots)[blockStart.L2BlockNumber] == nil { + position := GetSystemSCPosition(l2block.L2BlockNumber) + imStateRoot, err := stateDB.GetStorageAt(ctx, common.HexToAddress(SystemSC), big.NewInt(0).SetBytes(position), l2block.StateRoot) + if err != nil { + return err + } + tx.StateRoot = common.BigToHash(imStateRoot) + } else { + tx.StateRoot = common.BytesToHash((*imStateRoots)[blockStart.L2BlockNumber]) + } + entry, err = streamServer.AddStreamEntry(EntryTypeL2Tx, tx.Encode()) if err != nil { return err @@ -447,6 +469,22 @@ func GenerateDataStreamerFile(ctx context.Context, streamServer *datastreamer.St return err } +// GetSystemSCPosition computes the position of the intermediate state root for the system smart contract +func GetSystemSCPosition(blockNumber uint64) []byte { + v1 := big.NewInt(0).SetUint64(blockNumber).Bytes() + v2 := big.NewInt(0).SetUint64(uint64(posConstant)).Bytes() + + // Add 0s to make v1 and v2 32 bytes long + for len(v1) < 32 { + v1 = append([]byte{0}, v1...) + } + for len(v2) < 32 { + v2 = append([]byte{0}, v2...) + } + + return keccak256.Hash(v1, v2) +} + // computeFullBatches computes the full batches func computeFullBatches(batches []*DSBatch, l2Blocks []*DSL2Block, l2Txs []*DSL2Transaction) []*DSFullBatch { currentL2Block := 0 diff --git a/state/test/datastream_test.go b/state/test/datastream_test.go index 9c2002b842..03829c15b6 100644 --- a/state/test/datastream_test.go +++ b/state/test/datastream_test.go @@ -1,7 +1,9 @@ package test import ( + "fmt" "testing" + "time" "github.com/0xPolygonHermez/zkevm-node/state" "github.com/ethereum/go-ethereum/common" @@ -28,14 +30,15 @@ func TestL2BlockStartEncode(t *testing.T) { func TestL2TransactionEncode(t *testing.T) { l2Transaction := state.DSL2Transaction{ - EffectiveGasPricePercentage: 128, // 1 byte - IsValid: 1, // 1 byte - EncodedLength: 5, // 4 bytes - Encoded: []byte{1, 2, 3, 4, 5}, // 5 bytes + EffectiveGasPricePercentage: 128, // 1 byte + IsValid: 1, // 1 byte + StateRoot: common.HexToHash("0x010203"), // 32 bytes + EncodedLength: 5, // 4 bytes + Encoded: []byte{1, 2, 3, 4, 5}, // 5 bytes } encoded := l2Transaction.Encode() - expected := []byte{128, 1, 5, 0, 0, 0, 1, 2, 3, 4, 5} + expected := []byte{128, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 5, 0, 0, 0, 1, 2, 3, 4, 5} assert.Equal(t, expected, encoded) } @@ -53,3 +56,15 @@ func TestL2BlockEndEncode(t *testing.T) { assert.Equal(t, expected, encoded) } + +func TestCalculateSCPosition(t *testing.T) { + a := time.Now() + blockNumber := uint64(2934867) + expected := common.HexToHash("0xaa93c484856be45716623765b429a967296594ca362e61e91d671fb422e0f744") + position := state.GetSystemSCPosition(blockNumber) + assert.Equal(t, expected, common.BytesToHash(position)) + b := time.Now() + + c := b.Sub(a) + fmt.Println(c) +} diff --git a/tools/datastreamer/config/config.go b/tools/datastreamer/config/config.go index 8161a6ec81..4bb0592c6e 100644 --- a/tools/datastreamer/config/config.go +++ b/tools/datastreamer/config/config.go @@ -8,7 +8,6 @@ import ( "github.com/0xPolygonHermez/zkevm-data-streamer/datastreamer" "github.com/0xPolygonHermez/zkevm-data-streamer/log" "github.com/0xPolygonHermez/zkevm-node/db" - "github.com/0xPolygonHermez/zkevm-node/merkletree" "github.com/0xPolygonHermez/zkevm-node/state/runtime/executor" "github.com/mitchellh/mapstructure" "github.com/spf13/viper" @@ -28,15 +27,21 @@ type OnlineConfig struct { StreamType datastreamer.StreamType `mapstructure:"StreamType"` } +// MTConfig is the configuration for the merkle tree +type MTConfig struct { + URI string `mapstructure:"URI"` + MaxThreads int `mapstructure:"MaxThreads"` +} + // Config is the configuration for the tool type Config struct { - ChainID uint64 `mapstructure:"ChainID"` - Online OnlineConfig `mapstructure:"Online"` - Offline datastreamer.Config `mapstructure:"Offline"` - StateDB db.Config `mapstructure:"StateDB"` - Executor executor.Config `mapstructure:"Executor"` - MerkeTree merkletree.Config `mapstructure:"MerkeTree"` - Log log.Config `mapstructure:"Log"` + ChainID uint64 `mapstructure:"ChainID"` + Online OnlineConfig `mapstructure:"Online"` + Offline datastreamer.Config `mapstructure:"Offline"` + StateDB db.Config `mapstructure:"StateDB"` + Executor executor.Config `mapstructure:"Executor"` + MerkleTree MTConfig `mapstructure:"MerkleTree"` + Log log.Config `mapstructure:"Log"` } // Default parses the default configuration values. diff --git a/tools/datastreamer/config/default.go b/tools/datastreamer/config/default.go index 7a2b8aea49..261906e50d 100644 --- a/tools/datastreamer/config/default.go +++ b/tools/datastreamer/config/default.go @@ -25,8 +25,9 @@ MaxConns = 200 URI = "zkevm-prover:50071" MaxGRPCMessageSize = 100000000 -[MerkeTree] +[MerkleTree] URI = "zkevm-prover:50061" +MaxThreads = 20 [Log] Environment = "development" # "production" or "development" diff --git a/tools/datastreamer/config/tool.config.toml b/tools/datastreamer/config/tool.config.toml index 2326418375..0e088db007 100644 --- a/tools/datastreamer/config/tool.config.toml +++ b/tools/datastreamer/config/tool.config.toml @@ -21,8 +21,9 @@ MaxConns = 200 URI = "zkevm-prover:50071" MaxGRPCMessageSize = 100000000 -[MerkeTree] +[MerkleTree] URI = "zkevm-prover:50061" +MaxThreads = 20 [Log] Environment = "development" diff --git a/tools/datastreamer/main.go b/tools/datastreamer/main.go index 80e692c9d2..44c77c6fd4 100644 --- a/tools/datastreamer/main.go +++ b/tools/datastreamer/main.go @@ -4,7 +4,9 @@ import ( "context" "encoding/binary" "fmt" + "math/big" "os" + "sync" "time" "github.com/0xPolygonHermez/zkevm-data-streamer/datastreamer" @@ -190,10 +192,50 @@ func generate(cliCtx *cli.Context) error { os.Exit(1) } defer stateSqlDB.Close() - stateDB := state.NewPostgresStorage(state.Config{}, stateSqlDB) + stateDBStorage := state.NewPostgresStorage(state.Config{}, stateSqlDB) log.Debug("Connected to the database") - err = state.GenerateDataStreamerFile(cliCtx.Context, streamServer, stateDB, false) + mtDBServerConfig := merkletree.Config{URI: c.MerkleTree.URI} + var mtDBCancel context.CancelFunc + mtDBServiceClient, mtDBClientConn, mtDBCancel := merkletree.NewMTDBServiceClient(cliCtx.Context, mtDBServerConfig) + defer func() { + mtDBCancel() + mtDBClientConn.Close() + }() + stateTree := merkletree.NewStateTree(mtDBServiceClient) + log.Debug("Connected to the merkle tree") + + stateDB := state.NewState(state.Config{}, stateDBStorage, nil, stateTree, nil) + + // Calculate intermediate state roots + var imStateRoots map[uint64][]byte + var imStateRootsMux *sync.Mutex = new(sync.Mutex) + var wg sync.WaitGroup + + lastL2BlockHeader, err := stateDB.GetLastL2BlockHeader(cliCtx.Context, nil) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + maxL2Block := lastL2BlockHeader.Number.Uint64() + imStateRoots = make(map[uint64][]byte, maxL2Block) + + for x := 0; x < c.MerkleTree.MaxThreads; x++ { + start := uint64(x) * (maxL2Block / uint64(c.MerkleTree.MaxThreads)) + end := uint64(x+1)*(maxL2Block/uint64(c.MerkleTree.MaxThreads)) - 1 + + wg.Add(1) + go func(i int) { + defer wg.Done() + log.Debugf("Thread %d: Start: %d, End: %d\n", i, start, end) + getImStateRoots(cliCtx.Context, start, end, &imStateRoots, imStateRootsMux, stateDB, lastL2BlockHeader.Root) + }(x) + } + + wg.Wait() + + err = state.GenerateDataStreamerFile(cliCtx.Context, streamServer, stateDB, false, &imStateRoots) if err != nil { fmt.Printf("Error: %v\n", err) os.Exit(1) @@ -204,6 +246,21 @@ func generate(cliCtx *cli.Context) error { return nil } +func getImStateRoots(ctx context.Context, start, end uint64, isStateRoots *map[uint64][]byte, imStateRootMux *sync.Mutex, stateDB *state.State, stateRoot common.Hash) { + for x := start; x <= end; x++ { + // Populate intermediate state root + position := state.GetSystemSCPosition(x) + imStateRoot, err := stateDB.GetStorageAt(ctx, common.HexToAddress(state.SystemSC), big.NewInt(0).SetBytes(position), stateRoot) + if err != nil { + log.Errorf("Error: %v\n", err) + os.Exit(1) + } + imStateRootMux.Lock() + (*isStateRoots)[x] = imStateRoot.Bytes() + imStateRootMux.Unlock() + } +} + func reprocess(cliCtx *cli.Context) error { c, err := config.Load(cliCtx) if err != nil { @@ -239,7 +296,7 @@ func reprocess(cliCtx *cli.Context) error { if currentL2BlockNumber == 0 { printColored(color.FgHiYellow, "\n\nSetting Genesis block\n\n") - mtDBServerConfig := merkletree.Config{URI: c.MerkeTree.URI} + mtDBServerConfig := merkletree.Config{URI: c.MerkleTree.URI} var mtDBCancel context.CancelFunc mtDBServiceClient, mtDBClientConn, mtDBCancel := merkletree.NewMTDBServiceClient(ctx, mtDBServerConfig) defer func() { @@ -668,6 +725,8 @@ func printEntry(entry datastreamer.FileEntry) { printColored(color.FgHiWhite, fmt.Sprintf("%d\n", dsTx.EffectiveGasPricePercentage)) printColored(color.FgGreen, "Is Valid........: ") printColored(color.FgHiWhite, fmt.Sprintf("%t\n", dsTx.IsValid == 1)) + printColored(color.FgGreen, "State Root......: ") + printColored(color.FgHiWhite, fmt.Sprint(dsTx.StateRoot.Hex()+"\n")) printColored(color.FgGreen, "Encoded Length..: ") printColored(color.FgHiWhite, fmt.Sprintf("%d\n", dsTx.EncodedLength)) printColored(color.FgGreen, "Encoded.........: ")