Skip to content

Commit

Permalink
fix(dot/state): inject mutex protected tries to states (#2287)
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 authored Feb 16, 2022
1 parent 63303a3 commit 67a9bbb
Show file tree
Hide file tree
Showing 19 changed files with 126 additions and 122 deletions.
5 changes: 3 additions & 2 deletions dot/rpc/modules/dev_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ func newState(t *testing.T) (*state.BlockState, *state.EpochState) {

db := state.NewInMemoryDB(t)

_, _, genesisHeader := genesis.NewTestGenesisWithTrieAndHeader(t)
bs, err := state.NewBlockStateFromGenesis(db, genesisHeader, telemetryMock)
_, genesisTrie, genesisHeader := genesis.NewTestGenesisWithTrieAndHeader(t)
tries := state.NewTries(genesisTrie)
bs, err := state.NewBlockStateFromGenesis(db, tries, genesisHeader, telemetryMock)
require.NoError(t, err)
es, err := state.NewEpochStateFromGenesis(db, bs, genesisBABEConfig)
require.NoError(t, err)
Expand Down
16 changes: 7 additions & 9 deletions dot/state/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ import (
)

const (
pruneKeyBufferSize = 1000
blockPrefix = "block"
blockPrefix = "block"
)

var (
Expand Down Expand Up @@ -60,6 +59,7 @@ type BlockState struct {
genesisHash common.Hash
lastFinalised common.Hash
unfinalisedBlocks *sync.Map // map[common.Hash]*types.Block
tries *Tries

// block notifiers
imported map[chan *types.Block]struct{}
Expand All @@ -69,21 +69,19 @@ type BlockState struct {
runtimeUpdateSubscriptionsLock sync.RWMutex
runtimeUpdateSubscriptions map[uint32]chan<- runtime.Version

pruneKeyCh chan *types.Header

telemetry telemetry.Client
}

// NewBlockState will create a new BlockState backed by the database located at basePath
func NewBlockState(db chaindb.Database, telemetry telemetry.Client) (*BlockState, error) {
func NewBlockState(db chaindb.Database, trs *Tries, telemetry telemetry.Client) (*BlockState, error) {
bs := &BlockState{
dbPath: db.Path(),
baseState: NewBaseState(db),
db: chaindb.NewTable(db, blockPrefix),
unfinalisedBlocks: new(sync.Map),
tries: trs,
imported: make(map[chan *types.Block]struct{}),
finalised: make(map[chan *types.FinalisationInfo]struct{}),
pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize),
runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version),
telemetry: telemetry,
}
Expand All @@ -107,16 +105,16 @@ func NewBlockState(db chaindb.Database, telemetry telemetry.Client) (*BlockState

// NewBlockStateFromGenesis initialises a BlockState from a genesis header,
// saving it to the database located at basePath
func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header,
telemetryMailer telemetry.Client) (*BlockState, error) {
func NewBlockStateFromGenesis(db chaindb.Database, trs *Tries, header *types.Header,
telemetryMailer telemetry.Client) (*BlockState, error) { // TODO CHECKTEST
bs := &BlockState{
bt: blocktree.NewBlockTreeFromRoot(header),
baseState: NewBaseState(db),
db: chaindb.NewTable(db, blockPrefix),
unfinalisedBlocks: new(sync.Map),
tries: trs,
imported: make(map[chan *types.Block]struct{}),
finalised: make(map[chan *types.FinalisationInfo]struct{}),
pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize),
runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version),
genesisHash: header.Hash(),
lastFinalised: header.Hash(),
Expand Down
2 changes: 1 addition & 1 deletion dot/state/block_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

func TestGetSet_ReceiptMessageQueue_Justification(t *testing.T) {
s := newTestBlockState(t, nil)
s := newTestBlockState(t, nil, newTriesEmpty())
require.NotNil(t, s)

var genesisHeader = &types.Header{
Expand Down
6 changes: 4 additions & 2 deletions dot/state/block_finalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er
continue
}

bs.tries.delete(block.Header.StateRoot)

logger.Tracef("pruned block number %s with hash %s", block.Header.Number, hash)
bs.pruneKeyCh <- &block.Header
}

// if nothing was previously finalised, set the first slot of the network to the
Expand Down Expand Up @@ -238,8 +239,9 @@ func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error {
continue
}

bs.tries.delete(block.Header.StateRoot)

logger.Tracef("cleaned out finalised block from memory; block number %s with hash %s", block.Header.Number, hash)
bs.pruneKeyCh <- &block.Header
}

return batch.Flush()
Expand Down
6 changes: 3 additions & 3 deletions dot/state/block_finalisation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func TestHighestRoundAndSetID(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
round, setID, err := bs.GetHighestRoundAndSetID()
require.NoError(t, err)
require.Equal(t, uint64(0), round)
Expand Down Expand Up @@ -61,7 +61,7 @@ func TestHighestRoundAndSetID(t *testing.T) {
}

func TestBlockState_SetFinalisedHash(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
h, err := bs.GetFinalisedHash(0, 0)
require.NoError(t, err)
require.Equal(t, testGenesisHeader.Hash(), h)
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestBlockState_SetFinalisedHash(t *testing.T) {
}

func TestSetFinalisedHash_setFirstSlotOnFinalisation(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
firstSlot := uint64(42069)

digest := types.NewDigest()
Expand Down
15 changes: 8 additions & 7 deletions dot/state/block_notify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ import (
"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/runtime"
runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/stretchr/testify/require"
)

var testMessageTimeout = time.Second * 3

func TestImportChannel(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))
ch := bs.GetImportedBlockNotifierChannel()

defer bs.FreeImportedBlockNotifierChannel(ch)
Expand All @@ -35,7 +36,7 @@ func TestImportChannel(t *testing.T) {
}

func TestFreeImportedBlockNotifierChannel(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))
ch := bs.GetImportedBlockNotifierChannel()
require.Equal(t, 1, len(bs.imported))

Expand All @@ -44,7 +45,7 @@ func TestFreeImportedBlockNotifierChannel(t *testing.T) {
}

func TestFinalizedChannel(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))

ch := bs.GetFinalisedNotifierChannel()

Expand All @@ -66,7 +67,7 @@ func TestFinalizedChannel(t *testing.T) {
}

func TestImportChannel_Multi(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))

num := 5
chs := make([]chan *types.Block, num)
Expand Down Expand Up @@ -99,7 +100,7 @@ func TestImportChannel_Multi(t *testing.T) {
}

func TestFinalizedChannel_Multi(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))

num := 5
chs := make([]chan *types.FinalisationInfo, num)
Expand Down Expand Up @@ -136,7 +137,7 @@ func TestFinalizedChannel_Multi(t *testing.T) {
}

func TestService_RegisterUnRegisterRuntimeUpdatedChannel(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))
ch := make(chan<- runtime.Version)
chID, err := bs.RegisterRuntimeUpdatedChannel(ch)
require.NoError(t, err)
Expand All @@ -147,7 +148,7 @@ func TestService_RegisterUnRegisterRuntimeUpdatedChannel(t *testing.T) {
}

func TestService_RegisterUnRegisterConcurrentCalls(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))

go func() {
for i := 0; i < 100; i++ {
Expand Down
4 changes: 3 additions & 1 deletion dot/state/block_race_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ func TestConcurrencySetHeader(t *testing.T) {
dbs[i] = NewInMemoryDB(t)
}

tries := NewTries(trie.NewEmptyTrie()) // not used in this test

pend := new(sync.WaitGroup)
pend.Add(threads)
for i := 0; i < threads; i++ {
go func(index int) {
defer pend.Done()

bs, err := NewBlockStateFromGenesis(dbs[index], testGenesisHeader, telemetryMock)
bs, err := NewBlockStateFromGenesis(dbs[index], tries, testGenesisHeader, telemetryMock)
require.NoError(t, err)

header := &types.Header{
Expand Down
28 changes: 14 additions & 14 deletions dot/state/block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var testGenesisHeader = &types.Header{
Digest: types.NewDigest(),
}

func newTestBlockState(t *testing.T, header *types.Header) *BlockState {
func newTestBlockState(t *testing.T, header *types.Header, tries *Tries) *BlockState {
ctrl := gomock.NewController(t)
telemetryMock := NewMockClient(ctrl)
telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes()
Expand All @@ -35,13 +35,13 @@ func newTestBlockState(t *testing.T, header *types.Header) *BlockState {
header = testGenesisHeader
}

bs, err := NewBlockStateFromGenesis(db, header, telemetryMock)
bs, err := NewBlockStateFromGenesis(db, tries, header, telemetryMock)
require.NoError(t, err)
return bs
}

func TestSetAndGetHeader(t *testing.T) {
bs := newTestBlockState(t, nil)
bs := newTestBlockState(t, nil, newTriesEmpty())

header := &types.Header{
Number: big.NewInt(0),
Expand All @@ -58,7 +58,7 @@ func TestSetAndGetHeader(t *testing.T) {
}

func TestHasHeader(t *testing.T) {
bs := newTestBlockState(t, nil)
bs := newTestBlockState(t, nil, newTriesEmpty())

header := &types.Header{
Number: big.NewInt(0),
Expand All @@ -75,7 +75,7 @@ func TestHasHeader(t *testing.T) {
}

func TestGetBlockByNumber(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())

blockHeader := &types.Header{
ParentHash: testGenesisHeader.Hash(),
Expand All @@ -97,7 +97,7 @@ func TestGetBlockByNumber(t *testing.T) {
}

func TestAddBlock(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())

// Create header
header0 := &types.Header{
Expand Down Expand Up @@ -160,7 +160,7 @@ func TestAddBlock(t *testing.T) {
}

func TestGetSlotForBlock(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
expectedSlot := uint64(77)

babeHeader := types.NewBabeDigest()
Expand Down Expand Up @@ -191,7 +191,7 @@ func TestGetSlotForBlock(t *testing.T) {
}

func TestIsBlockOnCurrentChain(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
currChain, branchChains := AddBlocksToState(t, bs, 3, false)

for _, header := range currChain {
Expand All @@ -214,7 +214,7 @@ func TestIsBlockOnCurrentChain(t *testing.T) {
}

func TestAddBlock_BlockNumberToHash(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
currChain, branchChains := AddBlocksToState(t, bs, 8, false)

bestHash := bs.BestBlockHash()
Expand Down Expand Up @@ -262,7 +262,7 @@ func TestAddBlock_BlockNumberToHash(t *testing.T) {
}

func TestFinalization_DeleteBlock(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
AddBlocksToState(t, bs, 5, false)

btBefore := bs.bt.DeepCopy()
Expand Down Expand Up @@ -317,7 +317,7 @@ func TestFinalization_DeleteBlock(t *testing.T) {
}

func TestGetHashByNumber(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())

res, err := bs.GetHashByNumber(big.NewInt(0))
require.NoError(t, err)
Expand All @@ -344,7 +344,7 @@ func TestGetHashByNumber(t *testing.T) {

func TestAddBlock_WithReOrg(t *testing.T) {
t.Skip() // TODO: this should be fixed after state refactor PR
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())

header1a := &types.Header{
Number: big.NewInt(1),
Expand Down Expand Up @@ -453,7 +453,7 @@ func TestAddBlock_WithReOrg(t *testing.T) {
}

func TestAddBlockToBlockTree(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())

header := &types.Header{
Number: big.NewInt(1),
Expand All @@ -473,7 +473,7 @@ func TestAddBlockToBlockTree(t *testing.T) {
}

func TestNumberIsFinalised(t *testing.T) {
bs := newTestBlockState(t, testGenesisHeader)
bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty())
fin, err := bs.NumberIsFinalised(big.NewInt(0))
require.NoError(t, err)
require.True(t, fin)
Expand Down
6 changes: 4 additions & 2 deletions dot/state/epoch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/crypto/sr25519"
"github.com/ChainSafe/gossamer/lib/keystore"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/ChainSafe/gossamer/pkg/scale"

"github.com/stretchr/testify/require"
Expand All @@ -28,7 +29,8 @@ var genesisBABEConfig = &types.BabeConfiguration{

func newEpochStateFromGenesis(t *testing.T) *EpochState {
db := NewInMemoryDB(t)
s, err := NewEpochStateFromGenesis(db, newTestBlockState(t, nil), genesisBABEConfig)
blockState := newTestBlockState(t, nil, NewTries(trie.NewEmptyTrie()))
s, err := NewEpochStateFromGenesis(db, blockState, genesisBABEConfig)
require.NoError(t, err)
return s
}
Expand Down Expand Up @@ -184,7 +186,7 @@ func TestEpochState_SetAndGetSlotDuration(t *testing.T) {

func TestEpochState_GetEpochFromTime(t *testing.T) {
s := newEpochStateFromGenesis(t)
s.blockState = newTestBlockState(t, testGenesisHeader)
s.blockState = newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie()))

epochDuration, err := time.ParseDuration(
fmt.Sprintf("%dms",
Expand Down
8 changes: 8 additions & 0 deletions dot/state/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,17 @@ import (
"testing"
"time"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/stretchr/testify/require"
)

func newTriesEmpty() *Tries {
return &Tries{
rootToTrie: make(map[common.Hash]*trie.Trie),
}
}

// newGenerator creates a new PRNG seeded with the
// unix nanoseconds value of the current time.
func newGenerator() (prng *rand.Rand) {
Expand Down
Loading

0 comments on commit 67a9bbb

Please sign in to comment.