Skip to content

Commit

Permalink
fix(dot/core): fix the race condition in TrieState (#2499)
Browse files Browse the repository at this point in the history
Passing nil hash argument in `TrieState(nil)` returns `TrieState(stateRoot hash of best block)` and `GetRuntime(nil)` returns `GetRuntime(best block hash)`. 
In `HandleSubmittedExtrinsic`, It is likely that in the time between those function calls, a block has been imported, in which case blocks used by them would not match, which is not what we expect. This commit fixes that by getting the best block hash first and using them, instead of passing nil.

Fixes #2402
  • Loading branch information
kishansagathiya authored May 3, 2022
1 parent 014629d commit 804069c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
12 changes: 10 additions & 2 deletions dot/core/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"sync"

"github.com/ChainSafe/gossamer/dot/network"
Expand Down Expand Up @@ -495,12 +496,19 @@ func (s *Service) HandleSubmittedExtrinsic(ext types.Extrinsic) error {
return nil
}

ts, err := s.storageState.TrieState(nil)
bestBlockHash := s.blockState.BestBlockHash()

stateRoot, err := s.storageState.GetStateRootFromBlock(&bestBlockHash)
if err != nil {
return fmt.Errorf("could not get state root from block %s: %w", bestBlockHash, err)
}

ts, err := s.storageState.TrieState(stateRoot)
if err != nil {
return err
}

rt, err := s.blockState.GetRuntime(nil)
rt, err := s.blockState.GetRuntime(&bestBlockHash)
if err != nil {
logger.Critical("failed to get runtime")
return err
Expand Down
42 changes: 31 additions & 11 deletions dot/core/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1068,10 +1068,15 @@ func TestServiceHandleSubmittedExtrinsic(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockStorageState := NewMockStorageState(ctrl)
mockStorageState.EXPECT().TrieState(nil).Return(nil, errDummyErr)
mockStorageState.EXPECT().TrieState(&common.Hash{}).Return(nil, errDummyErr)
mockStorageState.EXPECT().GetStateRootFromBlock(&common.Hash{}).Return(&common.Hash{}, nil)

mockBlockState := NewMockBlockState(ctrl)
mockBlockState.EXPECT().BestBlockHash().Return(common.Hash{})
mockTxnState := NewMockTransactionState(ctrl)
mockTxnState.EXPECT().Exists(nil)
service := &Service{
blockState: mockBlockState,
storageState: mockStorageState,
transactionState: mockTxnState,
net: NewMockNetwork(ctrl),
Expand All @@ -1082,10 +1087,15 @@ func TestServiceHandleSubmittedExtrinsic(t *testing.T) {
t.Run("get runtime err", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockStorageState := NewMockStorageState(ctrl)
mockStorageState.EXPECT().TrieState(nil).Return(&rtstorage.TrieState{}, nil)

mockBlockState := NewMockBlockState(ctrl)
mockBlockState.EXPECT().GetRuntime(nil).Return(nil, errDummyErr)
mockBlockState.EXPECT().BestBlockHash().Return(common.Hash{})
mockBlockState.EXPECT().GetRuntime(&common.Hash{}).Return(nil, errDummyErr)

mockStorageState := NewMockStorageState(ctrl)
mockStorageState.EXPECT().TrieState(&common.Hash{}).Return(&rtstorage.TrieState{}, nil)
mockStorageState.EXPECT().GetStateRootFromBlock(&common.Hash{}).Return(&common.Hash{}, nil)

mockTxnState := NewMockTransactionState(ctrl)
mockTxnState.EXPECT().Exists(nil).MaxTimes(2)
service := &Service{
Expand All @@ -1100,13 +1110,18 @@ func TestServiceHandleSubmittedExtrinsic(t *testing.T) {
t.Run("validate txn err", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockBlockState := NewMockBlockState(ctrl)
mockBlockState.EXPECT().BestBlockHash().Return(common.Hash{})
runtimeMockErr := new(mocksruntime.Instance)
mockBlockState.EXPECT().GetRuntime(&common.Hash{}).Return(runtimeMockErr, nil).MaxTimes(2)

mockStorageState := NewMockStorageState(ctrl)
mockStorageState.EXPECT().TrieState(nil).Return(&rtstorage.TrieState{}, nil)
mockStorageState.EXPECT().TrieState(&common.Hash{}).Return(&rtstorage.TrieState{}, nil)
mockStorageState.EXPECT().GetStateRootFromBlock(&common.Hash{}).Return(&common.Hash{}, nil)

mockTxnState := NewMockTransactionState(ctrl)
mockTxnState.EXPECT().Exists(types.Extrinsic{})
runtimeMockErr := new(mocksruntime.Instance)
mockBlockState := NewMockBlockState(ctrl)
mockBlockState.EXPECT().GetRuntime(nil).Return(runtimeMockErr, nil).MaxTimes(2)

runtimeMockErr.On("SetContextStorage", &rtstorage.TrieState{})
runtimeMockErr.On("ValidateTransaction", externalExt).Return(nil, errDummyErr)
service := &Service{
Expand All @@ -1121,14 +1136,19 @@ func TestServiceHandleSubmittedExtrinsic(t *testing.T) {
t.Run("happy path", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockStorageState := NewMockStorageState(ctrl)
mockStorageState.EXPECT().TrieState(nil).Return(&rtstorage.TrieState{}, nil)

runtimeMock := new(mocksruntime.Instance)
mockBlockState := NewMockBlockState(ctrl)
mockBlockState.EXPECT().GetRuntime(nil).Return(runtimeMock, nil).MaxTimes(2)
mockBlockState.EXPECT().BestBlockHash().Return(common.Hash{})
mockBlockState.EXPECT().GetRuntime(&common.Hash{}).Return(runtimeMock, nil).MaxTimes(2)
runtimeMock.On("SetContextStorage", &rtstorage.TrieState{})
runtimeMock.On("ValidateTransaction", externalExt).
Return(&transaction.Validity{Propagate: true}, nil)

mockStorageState := NewMockStorageState(ctrl)
mockStorageState.EXPECT().TrieState(&common.Hash{}).Return(&rtstorage.TrieState{}, nil)
mockStorageState.EXPECT().GetStateRootFromBlock(&common.Hash{}).Return(&common.Hash{}, nil)

mockTxnState := NewMockTransactionState(ctrl)
mockTxnState.EXPECT().Exists(types.Extrinsic{}).MaxTimes(2)
mockTxnState.EXPECT().AddToPool(transaction.NewValidTransaction(ext, &transaction.Validity{Propagate: true}))
Expand Down
2 changes: 1 addition & 1 deletion dot/rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,6 @@ func TestHTTPServer_ServeHTTP(t *testing.T) {

_, message, err := c.ReadMessage()
require.NoError(t, err)
require.Equal(t, item.expected, message)
require.Equal(t, string(item.expected), string(message))
}
}

0 comments on commit 804069c

Please sign in to comment.