diff --git a/dot/rpc/modules/dev_integration_test.go b/dot/rpc/modules/dev_integration_test.go index 90c2e60ef2..34c8d0a198 100644 --- a/dot/rpc/modules/dev_integration_test.go +++ b/dot/rpc/modules/dev_integration_test.go @@ -42,8 +42,9 @@ func newState(t *testing.T) (*state.BlockState, *state.EpochState) { db := state.NewInMemoryDB(t) - _, _, genesisHeader := genesis.NewTestGenesisWithTrieAndHeader(t) - bs, err := state.NewBlockStateFromGenesis(db, genesisHeader, telemetryMock) + _, genesisTrie, genesisHeader := genesis.NewTestGenesisWithTrieAndHeader(t) + tries := state.NewTries(genesisTrie) + bs, err := state.NewBlockStateFromGenesis(db, tries, genesisHeader, telemetryMock) require.NoError(t, err) es, err := state.NewEpochStateFromGenesis(db, bs, genesisBABEConfig) require.NoError(t, err) diff --git a/dot/state/block.go b/dot/state/block.go index e31d8f40c9..59f3db4521 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -27,8 +27,7 @@ import ( ) const ( - pruneKeyBufferSize = 1000 - blockPrefix = "block" + blockPrefix = "block" ) var ( @@ -60,6 +59,7 @@ type BlockState struct { genesisHash common.Hash lastFinalised common.Hash unfinalisedBlocks *sync.Map // map[common.Hash]*types.Block + tries *Tries // block notifiers imported map[chan *types.Block]struct{} @@ -69,21 +69,19 @@ type BlockState struct { runtimeUpdateSubscriptionsLock sync.RWMutex runtimeUpdateSubscriptions map[uint32]chan<- runtime.Version - pruneKeyCh chan *types.Header - telemetry telemetry.Client } // NewBlockState will create a new BlockState backed by the database located at basePath -func NewBlockState(db chaindb.Database, telemetry telemetry.Client) (*BlockState, error) { +func NewBlockState(db chaindb.Database, trs *Tries, telemetry telemetry.Client) (*BlockState, error) { bs := &BlockState{ dbPath: db.Path(), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), unfinalisedBlocks: new(sync.Map), + tries: trs, imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), - pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), telemetry: telemetry, } @@ -107,16 +105,16 @@ func NewBlockState(db chaindb.Database, telemetry telemetry.Client) (*BlockState // NewBlockStateFromGenesis initialises a BlockState from a genesis header, // saving it to the database located at basePath -func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header, - telemetryMailer telemetry.Client) (*BlockState, error) { +func NewBlockStateFromGenesis(db chaindb.Database, trs *Tries, header *types.Header, + telemetryMailer telemetry.Client) (*BlockState, error) { // TODO CHECKTEST bs := &BlockState{ bt: blocktree.NewBlockTreeFromRoot(header), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), unfinalisedBlocks: new(sync.Map), + tries: trs, imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), - pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), genesisHash: header.Hash(), lastFinalised: header.Hash(), diff --git a/dot/state/block_data_test.go b/dot/state/block_data_test.go index 1a90dd1796..6ca6df58a3 100644 --- a/dot/state/block_data_test.go +++ b/dot/state/block_data_test.go @@ -15,7 +15,7 @@ import ( ) func TestGetSet_ReceiptMessageQueue_Justification(t *testing.T) { - s := newTestBlockState(t, nil) + s := newTestBlockState(t, nil, newTriesEmpty()) require.NotNil(t, s) var genesisHeader = &types.Header{ diff --git a/dot/state/block_finalisation.go b/dot/state/block_finalisation.go index ac19f469bc..f7fe569191 100644 --- a/dot/state/block_finalisation.go +++ b/dot/state/block_finalisation.go @@ -151,8 +151,9 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er continue } + bs.tries.delete(block.Header.StateRoot) + logger.Tracef("pruned block number %s with hash %s", block.Header.Number, hash) - bs.pruneKeyCh <- &block.Header } // if nothing was previously finalised, set the first slot of the network to the @@ -238,8 +239,9 @@ func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error { continue } + bs.tries.delete(block.Header.StateRoot) + logger.Tracef("cleaned out finalised block from memory; block number %s with hash %s", block.Header.Number, hash) - bs.pruneKeyCh <- &block.Header } return batch.Flush() diff --git a/dot/state/block_finalisation_test.go b/dot/state/block_finalisation_test.go index 20775d9a41..b55466c374 100644 --- a/dot/state/block_finalisation_test.go +++ b/dot/state/block_finalisation_test.go @@ -13,7 +13,7 @@ import ( ) func TestHighestRoundAndSetID(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) round, setID, err := bs.GetHighestRoundAndSetID() require.NoError(t, err) require.Equal(t, uint64(0), round) @@ -61,7 +61,7 @@ func TestHighestRoundAndSetID(t *testing.T) { } func TestBlockState_SetFinalisedHash(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) h, err := bs.GetFinalisedHash(0, 0) require.NoError(t, err) require.Equal(t, testGenesisHeader.Hash(), h) @@ -97,7 +97,7 @@ func TestBlockState_SetFinalisedHash(t *testing.T) { } func TestSetFinalisedHash_setFirstSlotOnFinalisation(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) firstSlot := uint64(42069) digest := types.NewDigest() diff --git a/dot/state/block_notify_test.go b/dot/state/block_notify_test.go index 7dbd6a9fee..f476e8836e 100644 --- a/dot/state/block_notify_test.go +++ b/dot/state/block_notify_test.go @@ -12,13 +12,14 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/runtime" runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/stretchr/testify/require" ) var testMessageTimeout = time.Second * 3 func TestImportChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := bs.GetImportedBlockNotifierChannel() defer bs.FreeImportedBlockNotifierChannel(ch) @@ -35,7 +36,7 @@ func TestImportChannel(t *testing.T) { } func TestFreeImportedBlockNotifierChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := bs.GetImportedBlockNotifierChannel() require.Equal(t, 1, len(bs.imported)) @@ -44,7 +45,7 @@ func TestFreeImportedBlockNotifierChannel(t *testing.T) { } func TestFinalizedChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := bs.GetFinalisedNotifierChannel() @@ -66,7 +67,7 @@ func TestFinalizedChannel(t *testing.T) { } func TestImportChannel_Multi(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) num := 5 chs := make([]chan *types.Block, num) @@ -99,7 +100,7 @@ func TestImportChannel_Multi(t *testing.T) { } func TestFinalizedChannel_Multi(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) num := 5 chs := make([]chan *types.FinalisationInfo, num) @@ -136,7 +137,7 @@ func TestFinalizedChannel_Multi(t *testing.T) { } func TestService_RegisterUnRegisterRuntimeUpdatedChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := make(chan<- runtime.Version) chID, err := bs.RegisterRuntimeUpdatedChannel(ch) require.NoError(t, err) @@ -147,7 +148,7 @@ func TestService_RegisterUnRegisterRuntimeUpdatedChannel(t *testing.T) { } func TestService_RegisterUnRegisterConcurrentCalls(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) go func() { for i := 0; i < 100; i++ { diff --git a/dot/state/block_race_test.go b/dot/state/block_race_test.go index 1f1284556c..c5bc2c1bc6 100644 --- a/dot/state/block_race_test.go +++ b/dot/state/block_race_test.go @@ -28,13 +28,15 @@ func TestConcurrencySetHeader(t *testing.T) { dbs[i] = NewInMemoryDB(t) } + tries := NewTries(trie.NewEmptyTrie()) // not used in this test + pend := new(sync.WaitGroup) pend.Add(threads) for i := 0; i < threads; i++ { go func(index int) { defer pend.Done() - bs, err := NewBlockStateFromGenesis(dbs[index], testGenesisHeader, telemetryMock) + bs, err := NewBlockStateFromGenesis(dbs[index], tries, testGenesisHeader, telemetryMock) require.NoError(t, err) header := &types.Header{ diff --git a/dot/state/block_test.go b/dot/state/block_test.go index 50e6513f5f..45508e8a53 100644 --- a/dot/state/block_test.go +++ b/dot/state/block_test.go @@ -25,7 +25,7 @@ var testGenesisHeader = &types.Header{ Digest: types.NewDigest(), } -func newTestBlockState(t *testing.T, header *types.Header) *BlockState { +func newTestBlockState(t *testing.T, header *types.Header, tries *Tries) *BlockState { ctrl := gomock.NewController(t) telemetryMock := NewMockClient(ctrl) telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() @@ -35,13 +35,13 @@ func newTestBlockState(t *testing.T, header *types.Header) *BlockState { header = testGenesisHeader } - bs, err := NewBlockStateFromGenesis(db, header, telemetryMock) + bs, err := NewBlockStateFromGenesis(db, tries, header, telemetryMock) require.NoError(t, err) return bs } func TestSetAndGetHeader(t *testing.T) { - bs := newTestBlockState(t, nil) + bs := newTestBlockState(t, nil, newTriesEmpty()) header := &types.Header{ Number: big.NewInt(0), @@ -58,7 +58,7 @@ func TestSetAndGetHeader(t *testing.T) { } func TestHasHeader(t *testing.T) { - bs := newTestBlockState(t, nil) + bs := newTestBlockState(t, nil, newTriesEmpty()) header := &types.Header{ Number: big.NewInt(0), @@ -75,7 +75,7 @@ func TestHasHeader(t *testing.T) { } func TestGetBlockByNumber(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) blockHeader := &types.Header{ ParentHash: testGenesisHeader.Hash(), @@ -97,7 +97,7 @@ func TestGetBlockByNumber(t *testing.T) { } func TestAddBlock(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) // Create header header0 := &types.Header{ @@ -160,7 +160,7 @@ func TestAddBlock(t *testing.T) { } func TestGetSlotForBlock(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) expectedSlot := uint64(77) babeHeader := types.NewBabeDigest() @@ -191,7 +191,7 @@ func TestGetSlotForBlock(t *testing.T) { } func TestIsBlockOnCurrentChain(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) currChain, branchChains := AddBlocksToState(t, bs, 3, false) for _, header := range currChain { @@ -214,7 +214,7 @@ func TestIsBlockOnCurrentChain(t *testing.T) { } func TestAddBlock_BlockNumberToHash(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) currChain, branchChains := AddBlocksToState(t, bs, 8, false) bestHash := bs.BestBlockHash() @@ -262,7 +262,7 @@ func TestAddBlock_BlockNumberToHash(t *testing.T) { } func TestFinalization_DeleteBlock(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) AddBlocksToState(t, bs, 5, false) btBefore := bs.bt.DeepCopy() @@ -317,7 +317,7 @@ func TestFinalization_DeleteBlock(t *testing.T) { } func TestGetHashByNumber(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) res, err := bs.GetHashByNumber(big.NewInt(0)) require.NoError(t, err) @@ -344,7 +344,7 @@ func TestGetHashByNumber(t *testing.T) { func TestAddBlock_WithReOrg(t *testing.T) { t.Skip() // TODO: this should be fixed after state refactor PR - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) header1a := &types.Header{ Number: big.NewInt(1), @@ -453,7 +453,7 @@ func TestAddBlock_WithReOrg(t *testing.T) { } func TestAddBlockToBlockTree(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) header := &types.Header{ Number: big.NewInt(1), @@ -473,7 +473,7 @@ func TestAddBlockToBlockTree(t *testing.T) { } func TestNumberIsFinalised(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) fin, err := bs.NumberIsFinalised(big.NewInt(0)) require.NoError(t, err) require.True(t, fin) diff --git a/dot/state/epoch_test.go b/dot/state/epoch_test.go index 52c25386bf..de3467f7a2 100644 --- a/dot/state/epoch_test.go +++ b/dot/state/epoch_test.go @@ -11,6 +11,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/crypto/sr25519" "github.com/ChainSafe/gossamer/lib/keystore" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/require" @@ -28,7 +29,8 @@ var genesisBABEConfig = &types.BabeConfiguration{ func newEpochStateFromGenesis(t *testing.T) *EpochState { db := NewInMemoryDB(t) - s, err := NewEpochStateFromGenesis(db, newTestBlockState(t, nil), genesisBABEConfig) + blockState := newTestBlockState(t, nil, NewTries(trie.NewEmptyTrie())) + s, err := NewEpochStateFromGenesis(db, blockState, genesisBABEConfig) require.NoError(t, err) return s } @@ -184,7 +186,7 @@ func TestEpochState_SetAndGetSlotDuration(t *testing.T) { func TestEpochState_GetEpochFromTime(t *testing.T) { s := newEpochStateFromGenesis(t) - s.blockState = newTestBlockState(t, testGenesisHeader) + s.blockState = newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) epochDuration, err := time.ParseDuration( fmt.Sprintf("%dms", diff --git a/dot/state/helpers_test.go b/dot/state/helpers_test.go index b7057ccc62..68d6246905 100644 --- a/dot/state/helpers_test.go +++ b/dot/state/helpers_test.go @@ -8,9 +8,17 @@ import ( "testing" "time" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/stretchr/testify/require" ) +func newTriesEmpty() *Tries { + return &Tries{ + rootToTrie: make(map[common.Hash]*trie.Trie), + } +} + // newGenerator creates a new PRNG seeded with the // unix nanoseconds value of the current time. func newGenerator() (prng *rand.Rand) { diff --git a/dot/state/initialize.go b/dot/state/initialize.go index 782ebe89e7..7b14afb6f2 100644 --- a/dot/state/initialize.go +++ b/dot/state/initialize.go @@ -62,14 +62,16 @@ func (s *Service) Initialise(gen *genesis.Genesis, header *types.Header, t *trie return fmt.Errorf("failed to write genesis values to database: %s", err) } + tries := NewTries(t) + // create block state from genesis block - blockState, err := NewBlockStateFromGenesis(db, header, s.Telemetry) + blockState, err := NewBlockStateFromGenesis(db, tries, header, s.Telemetry) if err != nil { return fmt.Errorf("failed to create block state from genesis: %s", err) } // create storage state from genesis trie - storageState, err := NewStorageState(db, blockState, t, pruner.Config{}) + storageState, err := NewStorageState(db, blockState, tries, pruner.Config{}) if err != nil { return fmt.Errorf("failed to create storage state from trie: %s", err) } diff --git a/dot/state/offline_pruner.go b/dot/state/offline_pruner.go index eedc26d7ec..ee45e4d612 100644 --- a/dot/state/offline_pruner.go +++ b/dot/state/offline_pruner.go @@ -41,9 +41,11 @@ func NewOfflinePruner(inputDBPath, prunedDBPath string, bloomSize uint64, return nil, fmt.Errorf("failed to load DB %w", err) } + tries := NewTries(trie.NewEmptyTrie()) + // create blockState state // NewBlockState on pruner execution does not use telemetry - blockState, err := NewBlockState(db, nil) + blockState, err := NewBlockState(db, tries, nil) if err != nil { return nil, fmt.Errorf("failed to create block state: %w", err) } @@ -60,7 +62,7 @@ func NewOfflinePruner(inputDBPath, prunedDBPath string, bloomSize uint64, } // load storage state - storageState, err := NewStorageState(db, blockState, trie.NewEmptyTrie(), pruner.Config{}) + storageState, err := NewStorageState(db, blockState, tries, pruner.Config{}) if err != nil { return nil, fmt.Errorf("failed to create new storage state %w", err) } diff --git a/dot/state/service.go b/dot/state/service.go index 48c03f99b7..75b4ae587f 100644 --- a/dot/state/service.go +++ b/dot/state/service.go @@ -114,9 +114,11 @@ func (s *Service) Start() error { return nil } + tries := NewTries(trie.NewEmptyTrie()) + var err error // create block state - s.Block, err = NewBlockState(s.db, s.Telemetry) + s.Block, err = NewBlockState(s.db, tries, s.Telemetry) if err != nil { return fmt.Errorf("failed to create block state: %w", err) } @@ -136,7 +138,7 @@ func (s *Service) Start() error { } // create storage state - s.Storage, err = NewStorageState(s.db, s.Block, trie.NewEmptyTrie(), pr) + s.Storage, err = NewStorageState(s.db, s.Block, tries, pr) if err != nil { return fmt.Errorf("failed to create storage state: %w", err) } @@ -167,9 +169,6 @@ func (s *Service) Start() error { ", highest number " + num.String() + " and genesis hash " + s.Block.genesisHash.String()) - // Start background goroutine to GC pruned keys. - go s.Storage.pruneStorage(s.closeCh) - return nil } diff --git a/dot/state/service_test.go b/dot/state/service_test.go index 897bb96066..f8cabd4c08 100644 --- a/dot/state/service_test.go +++ b/dot/state/service_test.go @@ -289,7 +289,7 @@ func TestService_PruneStorage(t *testing.T) { time.Sleep(1 * time.Second) for _, v := range prunedArr { - tr := serv.Storage.tries.get(v.hash) + tr := serv.Storage.blockState.tries.get(v.hash) require.Nil(t, tr) } } diff --git a/dot/state/storage.go b/dot/state/storage.go index 94772187df..4571c9279b 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -30,7 +30,7 @@ func errTrieDoesNotExist(hash common.Hash) error { // StorageState is the struct that holds the trie, db and lock type StorageState struct { blockState *BlockState - tries *tries + tries *Tries db chaindb.Database sync.RWMutex @@ -41,19 +41,14 @@ type StorageState struct { pruner pruner.Pruner } -// NewStorageState creates a new StorageState backed by the given trie and database located at basePath. +// NewStorageState creates a new StorageState backed by the given block state +// and database located at basePath. func NewStorageState(db chaindb.Database, blockState *BlockState, - t *trie.Trie, onlinePruner pruner.Config) (*StorageState, error) { + tries *Tries, onlinePruner pruner.Config) (*StorageState, error) { if db == nil { return nil, fmt.Errorf("cannot have nil database") } - if t == nil { - return nil, fmt.Errorf("cannot have nil trie") - } - - tries := newTries(t) - storageTable := chaindb.NewTable(db, storagePrefix) var p pruner.Pruner @@ -76,11 +71,6 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, }, nil } -func (s *StorageState) pruneKey(keyHeader *types.Header) { - logger.Tracef("pruning trie, number=%d hash=%s", keyHeader.Number, keyHeader.Hash()) - s.tries.delete(keyHeader.StateRoot) -} - // StoreTrie stores the given trie in the StorageState and writes it to the database func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) error { root := ts.MustRoot() @@ -314,14 +304,3 @@ func (s *StorageState) LoadCodeHash(hash *common.Hash) (common.Hash, error) { func (s *StorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ([][]byte, error) { return trie.GenerateProof(stateRoot[:], keys, s.db) } - -func (s *StorageState) pruneStorage(closeCh chan interface{}) { - for { - select { - case key := <-s.blockState.pruneKeyCh: - s.pruneKey(key) - case <-closeCh: - return - } - } -} diff --git a/dot/state/storage_test.go b/dot/state/storage_test.go index b455a66231..cc0c756277 100644 --- a/dot/state/storage_test.go +++ b/dot/state/storage_test.go @@ -23,9 +23,11 @@ import ( func newTestStorageState(t *testing.T) *StorageState { db := NewInMemoryDB(t) - bs := newTestBlockState(t, testGenesisHeader) + tries := newTriesEmpty() - s, err := NewStorageState(db, bs, trie.NewEmptyTrie(), pruner.Config{}) + bs := newTestBlockState(t, testGenesisHeader, tries) + + s, err := NewStorageState(db, bs, tries, pruner.Config{}) require.NoError(t, err) return s } @@ -99,7 +101,7 @@ func TestStorage_TrieState(t *testing.T) { time.Sleep(time.Millisecond * 100) // get trie from db - storage.tries.delete(root) + storage.blockState.tries.delete(root) ts3, err := storage.TrieState(&root) require.NoError(t, err) require.Equal(t, ts.Trie().MustHash(), ts3.Trie().MustHash()) @@ -131,19 +133,19 @@ func TestStorage_LoadFromDB(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. - storage.tries.delete(root) + storage.blockState.tries.delete(root) data, err := storage.GetStorage(&root, trieKV[0].key) require.NoError(t, err) require.Equal(t, trieKV[0].value, data) - storage.tries.delete(root) + storage.blockState.tries.delete(root) prefixKeys, err := storage.GetKeysWithPrefix(&root, []byte("ke")) require.NoError(t, err) require.Equal(t, 2, len(prefixKeys)) - storage.tries.delete(root) + storage.blockState.tries.delete(root) entries, err := storage.Entries(&root) require.NoError(t, err) @@ -161,7 +163,7 @@ func TestStorage_StoreTrie_NotSyncing(t *testing.T) { err = storage.StoreTrie(ts, nil) require.NoError(t, err) - require.Equal(t, 2, storage.tries.len()) + require.Equal(t, 2, storage.blockState.tries.len()) } func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { @@ -179,16 +181,18 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { "0", )) - blockState, err := NewBlockStateFromGenesis(db, genHeader, telemetryMock) - require.NoError(t, err) - testChildTrie := trie.NewEmptyTrie() testChildTrie.Put([]byte("keyInsidechild"), []byte("voila")) err = genTrie.PutChild([]byte("keyToChild"), testChildTrie) require.NoError(t, err) - storage, err := NewStorageState(db, blockState, genTrie, pruner.Config{}) + tries := NewTries(genTrie) + + blockState, err := NewBlockStateFromGenesis(db, tries, genHeader, telemetryMock) + require.NoError(t, err) + + storage, err := NewStorageState(db, blockState, tries, pruner.Config{}) require.NoError(t, err) trieState, err := runtime.NewTrieState(genTrie) @@ -208,7 +212,7 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. - storage.tries.delete(rootHash) + storage.blockState.tries.delete(rootHash) _, err = storage.GetStorageChild(&rootHash, []byte("keyToChild")) require.NoError(t, err) diff --git a/dot/state/tries.go b/dot/state/tries.go index e7afd3dbb1..21342f4d67 100644 --- a/dot/state/tries.go +++ b/dot/state/tries.go @@ -10,13 +10,17 @@ import ( "github.com/ChainSafe/gossamer/lib/trie" ) -type tries struct { +// Tries is a thread safe map of root hash +// to trie. +type Tries struct { rootToTrie map[common.Hash]*trie.Trie mapMutex sync.RWMutex } -func newTries(t *trie.Trie) *tries { - return &tries{ +// NewTries creates a new thread safe map of root hash +// to trie using the trie given as a first trie. +func NewTries(t *trie.Trie) *Tries { + return &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ t.MustHash(): t, }, @@ -25,7 +29,7 @@ func newTries(t *trie.Trie) *tries { // softSet sets the given trie at the given root hash // in the memory map only if it is not already set. -func (t *tries) softSet(root common.Hash, trie *trie.Trie) { +func (t *Tries) softSet(root common.Hash, trie *trie.Trie) { t.mapMutex.Lock() defer t.mapMutex.Unlock() @@ -37,7 +41,7 @@ func (t *tries) softSet(root common.Hash, trie *trie.Trie) { t.rootToTrie[root] = trie } -func (t *tries) delete(root common.Hash) { +func (t *Tries) delete(root common.Hash) { t.mapMutex.Lock() defer t.mapMutex.Unlock() delete(t.rootToTrie, root) @@ -45,7 +49,7 @@ func (t *tries) delete(root common.Hash) { // get retrieves the trie corresponding to the root hash given // from the in-memory thread safe map. -func (t *tries) get(root common.Hash) (tr *trie.Trie) { +func (t *Tries) get(root common.Hash) (tr *trie.Trie) { t.mapMutex.RLock() defer t.mapMutex.RUnlock() return t.rootToTrie[root] @@ -53,7 +57,7 @@ func (t *tries) get(root common.Hash) (tr *trie.Trie) { // len returns the current numbers of tries // stored in the in-memory map. -func (t *tries) len() int { +func (t *Tries) len() int { t.mapMutex.RLock() defer t.mapMutex.RUnlock() return len(t.rootToTrie) diff --git a/dot/state/tries_test.go b/dot/state/tries_test.go index 0a0bc1d865..5e3b90e7c1 100644 --- a/dot/state/tries_test.go +++ b/dot/state/tries_test.go @@ -12,14 +12,14 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_newTries(t *testing.T) { +func Test_NewTries(t *testing.T) { t.Parallel() tr := trie.NewEmptyTrie() - rootToTrie := newTries(tr) + rootToTrie := NewTries(tr) - expectedTries := &tries{ + expectedTries := &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ tr.MustHash(): tr, }, @@ -28,36 +28,36 @@ func Test_newTries(t *testing.T) { assert.Equal(t, expectedTries, rootToTrie) } -func Test_tries_softSet(t *testing.T) { +func Test_Tries_softSet(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries root common.Hash trie *trie.Trie - expectedTries *tries + expectedTries *Tries }{ "set new in map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, root: common.Hash{1, 2, 3}, trie: trie.NewEmptyTrie(), - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: trie.NewEmptyTrie(), }, }, }, "do not override in map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, }, root: common.Hash{1, 2, 3}, trie: trie.NewEmptyTrie(), - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, @@ -77,31 +77,31 @@ func Test_tries_softSet(t *testing.T) { } } -func Test_tries_delete(t *testing.T) { +func Test_Tries_delete(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries root common.Hash - expectedTries *tries + expectedTries *Tries }{ "not found": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, root: common.Hash{1, 2, 3}, - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, }, "deleted": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, }, root: common.Hash{1, 2, 3}, - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, }, @@ -118,16 +118,16 @@ func Test_tries_delete(t *testing.T) { }) } } -func Test_tries_get(t *testing.T) { +func Test_Tries_get(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries root common.Hash trie *trie.Trie }{ "found in map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: trie.NewTrie(&node.Leaf{ Key: []byte{1, 2, 3}, @@ -141,7 +141,7 @@ func Test_tries_get(t *testing.T) { }, "not found in map": { // similar to not found in database - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, root: common.Hash{1, 2, 3}, @@ -160,20 +160,20 @@ func Test_tries_get(t *testing.T) { } } -func Test_tries_len(t *testing.T) { +func Test_Tries_len(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries length int }{ "empty map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, }, "non empty map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, diff --git a/lib/grandpa/grandpa_test.go b/lib/grandpa/grandpa_test.go index f52c86f7a1..14a60f0714 100644 --- a/lib/grandpa/grandpa_test.go +++ b/lib/grandpa/grandpa_test.go @@ -60,7 +60,8 @@ func newTestState(t *testing.T) *state.Service { t.Cleanup(func() { db.Close() }) _, genTrie, _ := genesis.NewTestGenesisWithTrieAndHeader(t) - block, err := state.NewBlockStateFromGenesis(db, testGenesisHeader, telemetryMock) + tries := state.NewTries(genTrie) + block, err := state.NewBlockStateFromGenesis(db, tries, testGenesisHeader, telemetryMock) require.NoError(t, err) rtCfg := &wasmer.Config{} @@ -862,7 +863,6 @@ func TestFindParentWithNumber(t *testing.T) { p, err := gs.findParentWithNumber(v, 1) require.NoError(t, err) - t.Log(st.Block.BlocktreeAsString()) expected, err := st.Block.GetBlockByNumber(big.NewInt(1)) require.NoError(t, err)