diff --git a/lib/blocktree/blocktree.go b/lib/blocktree/blocktree.go index 5f8df858e9..f3aaf88da7 100644 --- a/lib/blocktree/blocktree.go +++ b/lib/blocktree/blocktree.go @@ -23,17 +23,15 @@ type BlockTree struct { root *node leaves *leafMap sync.RWMutex - nodeCache map[Hash]*node - runtime *sync.Map + runtime *sync.Map // map[Hash]runtime.Instance } // NewEmptyBlockTree creates a BlockTree with a nil head func NewEmptyBlockTree() *BlockTree { return &BlockTree{ - root: nil, - leaves: newEmptyLeafMap(), - nodeCache: make(map[Hash]*node), - runtime: &sync.Map{}, // map[Hash]runtime.Instance + root: nil, + leaves: newEmptyLeafMap(), + runtime: &sync.Map{}, } } @@ -49,20 +47,12 @@ func NewBlockTreeFromRoot(root *types.Header) *BlockTree { } return &BlockTree{ - root: n, - leaves: newLeafMap(n), - nodeCache: make(map[Hash]*node), - runtime: &sync.Map{}, + root: n, + leaves: newLeafMap(n), + runtime: &sync.Map{}, } } -// GenesisHash returns the hash of the genesis block -func (bt *BlockTree) GenesisHash() Hash { - bt.RLock() - defer bt.RUnlock() - return bt.root.hash -} - // AddBlock inserts the block as child of its parent node // Note: Assumes block has no children func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime time.Time) error { @@ -75,8 +65,7 @@ func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime time.Time) error } // Check if it already exists - n := bt.getNode(header.Hash()) - if n != nil { + if n := bt.getNode(header.Hash()); n != nil { return ErrBlockExists } @@ -87,17 +76,16 @@ func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime time.Time) error return errUnexpectedNumber } - n = &node{ + n := &node{ hash: header.Hash(), parent: parent, children: []*node{}, number: number, arrivalTime: arrivalTime, } + parent.addChild(n) bt.leaves.replace(parent, n) - bt.setInCache(n) - return nil } @@ -121,24 +109,8 @@ func (bt *BlockTree) GetAllBlocksAtNumber(hash common.Hash) (hashes []common.Has return bt.root.getNodesWithNumber(number, hashes) } -func (bt *BlockTree) setInCache(b *node) { - if b == nil { - return - } - - if _, has := bt.nodeCache[b.hash]; !has { - bt.nodeCache[b.hash] = b - } -} - // getNode finds and returns a node based on its Hash. Returns nil if not found. func (bt *BlockTree) getNode(h Hash) (ret *node) { - defer func() { bt.setInCache(ret) }() - - if b, ok := bt.nodeCache[h]; ok { - return b - } - if bt.root.hash == h { return bt.root } @@ -164,12 +136,6 @@ func (bt *BlockTree) getNode(h Hash) (ret *node) { func (bt *BlockTree) Prune(finalised Hash) (pruned []Hash) { bt.Lock() defer bt.Unlock() - defer func() { - for _, hash := range pruned { - delete(bt.nodeCache, hash) - bt.runtime.Delete(hash) - } - }() if finalised == bt.root.hash { return pruned @@ -190,6 +156,10 @@ func (bt *BlockTree) Prune(finalised Hash) (pruned []Hash) { bt.leaves.store(leaf.hash, leaf) } + for _, hash := range pruned { + bt.runtime.Delete(hash) + } + return pruned } @@ -218,7 +188,7 @@ func (bt *BlockTree) String() string { return fmt.Sprintf("%s\n%s\n", metadata, tree.Print()) } -// longestPath returns the path from the root to leftmost deepest leaf in BlockTree BT +// longestPath returns the path from the root to the deepest leaf in the blocktree func (bt *BlockTree) longestPath() []*node { dl := bt.deepestLeaf() var path []*node @@ -260,7 +230,7 @@ func (bt *BlockTree) SubBlockchain(start, end Hash) ([]Hash, error) { } -// deepestLeaf returns the leftmost deepest leaf in the block tree. +// deepestLeaf returns the deepest leaf in the block tree. func (bt *BlockTree) deepestLeaf() *node { return bt.leaves.deepestLeaf() } @@ -392,8 +362,8 @@ func (bt *BlockTree) GetArrivalTime(hash common.Hash) (time.Time, error) { bt.RLock() defer bt.RUnlock() - n, has := bt.nodeCache[hash] - if !has { + n := bt.getNode(hash) + if n == nil { return time.Time{}, ErrNodeNotFound } @@ -405,9 +375,7 @@ func (bt *BlockTree) DeepCopy() *BlockTree { bt.RLock() defer bt.RUnlock() - btCopy := &BlockTree{ - nodeCache: make(map[Hash]*node), - } + btCopy := &BlockTree{} if bt.root == nil { return btCopy @@ -424,10 +392,6 @@ func (bt *BlockTree) DeepCopy() *BlockTree { } } - for hash := range bt.nodeCache { - btCopy.nodeCache[hash] = btCopy.getNode(hash) - } - return btCopy } diff --git a/lib/blocktree/blocktree_test.go b/lib/blocktree/blocktree_test.go index f7fa43150e..f35b8462f4 100644 --- a/lib/blocktree/blocktree_test.go +++ b/lib/blocktree/blocktree_test.go @@ -283,12 +283,7 @@ func TestBlockTree_GetNode(t *testing.T) { } block := bt.getNode(branches[0].hash) - - cachedBlock, ok := bt.nodeCache[block.hash] - require.True(t, len(bt.nodeCache) > 0) - require.True(t, ok) - require.NotNil(t, cachedBlock) - require.Equal(t, cachedBlock, block) + require.NotNil(t, block) } func TestBlockTree_GetAllBlocksAtNumber(t *testing.T) { @@ -458,38 +453,15 @@ func TestBlockTree_Prune(t *testing.T) { } } -func TestBlockTree_PruneCache(t *testing.T) { - var bt *BlockTree - var branches []testBranch - - for { - bt, branches = createTestBlockTree(t, testHeader, 5) - if len(branches) > 0 && len(bt.getNode(branches[0].hash).children) > 1 { - break - } - } - - // pick some block to finalise - finalised := bt.root.children[0].children[0].children[0] - pruned := bt.Prune(finalised.hash) - - for _, prunedHash := range pruned { - block, ok := bt.nodeCache[prunedHash] - - require.False(t, ok) - require.Nil(t, block) - } -} - func TestBlockTree_GetHashByNumber(t *testing.T) { bt, _ := createTestBlockTree(t, testHeader, 8) best := bt.DeepestBlockHash() - bn := bt.nodeCache[best] + bn := bt.getNode(best) for i := int64(0); i < bn.number.Int64(); i++ { hash, err := bt.GetHashByNumber(big.NewInt(i)) require.NoError(t, err) - require.Equal(t, big.NewInt(i), bt.nodeCache[hash].number) + require.Equal(t, big.NewInt(i), bt.getNode(hash).number) desc, err := bt.IsDescendantOf(hash, best) require.NoError(t, err) require.True(t, desc, fmt.Sprintf("index %d failed, got hash=%s", i, hash)) @@ -506,17 +478,6 @@ func TestBlockTree_DeepCopy(t *testing.T) { bt, _ := createFlatTree(t, 8) btCopy := bt.DeepCopy() - for hash := range bt.nodeCache { - b, ok := btCopy.nodeCache[hash] - b2 := bt.nodeCache[hash] - - require.True(t, ok) - require.True(t, b != b2) - - require.True(t, equalNodeValue(b, b2)) - - } - require.True(t, equalNodeValue(bt.root, btCopy.root), "BlockTree heads not equal") require.True(t, equalLeaves(bt.leaves, btCopy.leaves), "BlockTree leaves not equal") diff --git a/lib/blocktree/leaves.go b/lib/blocktree/leaves.go index cc47bf6725..13d20cfa3f 100644 --- a/lib/blocktree/leaves.go +++ b/lib/blocktree/leaves.go @@ -13,6 +13,7 @@ import ( // leafMap provides quick lookup for existing leaves type leafMap struct { + sync.RWMutex smap *sync.Map // map[common.Hash]*node } @@ -24,8 +25,8 @@ func newEmptyLeafMap() *leafMap { func newLeafMap(n *node) *leafMap { smap := &sync.Map{} - for _, child := range n.getLeaves(nil) { - smap.Store(child.hash, child) + for _, leaf := range n.getLeaves(nil) { + smap.Store(leaf.hash, leaf) } return &leafMap{ @@ -33,12 +34,12 @@ func newLeafMap(n *node) *leafMap { } } -func (ls *leafMap) store(key Hash, value *node) { - ls.smap.Store(key, value) +func (lm *leafMap) store(key Hash, value *node) { + lm.smap.Store(key, value) } -func (ls *leafMap) load(key Hash) (*node, error) { - v, ok := ls.smap.Load(key) +func (lm *leafMap) load(key Hash) (*node, error) { + v, ok := lm.smap.Load(key) if !ok { return nil, errors.New("key not found") } @@ -47,18 +48,23 @@ func (ls *leafMap) load(key Hash) (*node, error) { } // Replace deletes the old node from the map and inserts the new one -func (ls *leafMap) replace(oldNode, newNode *node) { - ls.smap.Delete(oldNode.hash) - ls.store(newNode.hash, newNode) +func (lm *leafMap) replace(oldNode, newNode *node) { + lm.Lock() + defer lm.Unlock() + lm.smap.Delete(oldNode.hash) + lm.store(newNode.hash, newNode) } // DeepestLeaf searches the stored leaves to the find the one with the greatest number. // If there are two leaves with the same number, choose the one with the earliest arrival time. -func (ls *leafMap) deepestLeaf() *node { +func (lm *leafMap) deepestLeaf() *node { + lm.RLock() + defer lm.RUnlock() + max := big.NewInt(-1) var dLeaf *node - ls.smap.Range(func(h, n interface{}) bool { + lm.smap.Range(func(h, n interface{}) bool { if n == nil { return true } @@ -78,10 +84,13 @@ func (ls *leafMap) deepestLeaf() *node { return dLeaf } -func (ls *leafMap) toMap() map[common.Hash]*node { +func (lm *leafMap) toMap() map[common.Hash]*node { + lm.RLock() + defer lm.RUnlock() + mmap := make(map[common.Hash]*node) - ls.smap.Range(func(h, n interface{}) bool { + lm.smap.Range(func(h, n interface{}) bool { hash := h.(Hash) node := n.(*node) mmap[hash] = node @@ -91,10 +100,13 @@ func (ls *leafMap) toMap() map[common.Hash]*node { return mmap } -func (ls *leafMap) nodes() []*node { +func (lm *leafMap) nodes() []*node { + lm.RLock() + defer lm.RUnlock() + nodes := []*node{} - ls.smap.Range(func(h, n interface{}) bool { + lm.smap.Range(func(h, n interface{}) bool { node := n.(*node) nodes = append(nodes, node) return true diff --git a/lib/grandpa/message_handler_test.go b/lib/grandpa/message_handler_test.go index 8e9ba54b07..a7cec946ca 100644 --- a/lib/grandpa/message_handler_test.go +++ b/lib/grandpa/message_handler_test.go @@ -10,6 +10,7 @@ import ( "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto/ed25519" "github.com/ChainSafe/gossamer/lib/keystore" @@ -628,6 +629,56 @@ func TestMessageHandler_VerifyBlockJustification(t *testing.T) { require.NoError(t, err) err = gs.VerifyBlockJustification(testHash, data) require.NotNil(t, err) + require.Equal(t, blocktree.ErrEndNodeNotFound, err) +} + +func TestMessageHandler_VerifyBlockJustification_invalid(t *testing.T) { + auths := []types.GrandpaVoter{ + { + Key: *kr.Alice().Public().(*ed25519.PublicKey), + }, + { + Key: *kr.Bob().Public().(*ed25519.PublicKey), + }, + { + Key: *kr.Charlie().Public().(*ed25519.PublicKey), + }, + } + + gs, st := newTestService(t) + err := st.Grandpa.SetNextChange(auths, big.NewInt(1)) + require.NoError(t, err) + + body, err := types.NewBodyFromBytes([]byte{0}) + require.NoError(t, err) + + block := &types.Block{ + Header: *testHeader, + Body: *body, + } + + err = st.Block.AddBlock(block) + require.NoError(t, err) + + err = st.Grandpa.IncrementSetID() + require.NoError(t, err) + + setID, err := st.Grandpa.GetCurrentSetID() + require.NoError(t, err) + require.Equal(t, uint64(1), setID) + + genhash := st.Block.GenesisHash() + round := uint64(2) + number := uint32(2) + + // use wrong hash, shouldn't verify + precommits := buildTestJustification(t, 2, round+1, setID, kr, precommit) + just := newJustification(round+1, testHash, number, precommits) + just.Commit.Precommits[0].Vote.Hash = genhash + data, err := scale.Marshal(*just) + require.NoError(t, err) + err = gs.VerifyBlockJustification(testHash, data) + require.NotNil(t, err) require.Equal(t, ErrPrecommitBlockMismatch, err) // use wrong round, shouldn't verify