Skip to content

Commit

Permalink
fix(lib/grandpa): capped number of tracked vote messages (#2485)
Browse files Browse the repository at this point in the history
- Vote messages tracker
  - Removes oldest vote message when tracker capacity is reached
  - Efficient removal of multiple messages at any place in the tracker queue (linked list) if they get processed
  - Efficient removal of oldest message
  - Uses a bit more space to store each block hash + authority ID, for each vote message
  - Order is not modified for the same vote message (same block hash and authority id)
- Discard vote messages for more than 1 round in the future from the state round (thanks [andresilva](https://github.com/andresilva))
- Discard vote messages for more than 1 round in the past from the state round (thanks [andresilva](https://github.com/andresilva))
- Disable `addCatchUpResponse` (not implemented yet) to avoid a possible memory leak/abuse, see #1531 
- Comment with issue number about the reputation change of peers for bad vote messages

Co-authored-by: Timothy Wu <[email protected]>
  • Loading branch information
qdm12 and timwu20 authored May 30, 2022
1 parent 1f20d98 commit d2ee47e
Show file tree
Hide file tree
Showing 5 changed files with 625 additions and 97 deletions.
62 changes: 29 additions & 33 deletions lib/grandpa/message_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/crypto/ed25519"
"github.com/libp2p/go-libp2p-core/peer"
)

// tracker keeps track of messages that have been received, but have failed to
Expand All @@ -18,8 +18,8 @@ import (
type tracker struct {
blockState BlockState
handler *MessageHandler
// map of vote block hash -> array of VoteMessages for that hash
voteMessages map[common.Hash]map[ed25519.PublicKeyBytes]*networkVoteMessage
votes votesTracker

// map of commit block hash to commit message
commitMessages map[common.Hash]*CommitMessage
mapLock sync.Mutex
Expand All @@ -32,10 +32,11 @@ type tracker struct {
}

func newTracker(bs BlockState, handler *MessageHandler) *tracker {
const votesCapacity = 1000
return &tracker{
blockState: bs,
handler: handler,
voteMessages: make(map[common.Hash]map[ed25519.PublicKeyBytes]*networkVoteMessage),
votes: newVotesTracker(votesCapacity),
commitMessages: make(map[common.Hash]*CommitMessage),
mapLock: sync.Mutex{},
in: bs.GetImportedBlockNotifierChannel(),
Expand All @@ -53,21 +54,15 @@ func (t *tracker) stop() {
t.blockState.FreeImportedBlockNotifierChannel(t.in)
}

func (t *tracker) addVote(v *networkVoteMessage) {
if v.msg == nil {
func (t *tracker) addVote(peerID peer.ID, message *VoteMessage) {
if message == nil {
return
}

t.mapLock.Lock()
defer t.mapLock.Unlock()

msgs, has := t.voteMessages[v.msg.Message.BlockHash]
if !has {
msgs = make(map[ed25519.PublicKeyBytes]*networkVoteMessage)
t.voteMessages[v.msg.Message.BlockHash] = msgs
}

msgs[v.msg.Message.AuthorityID] = v
t.votes.add(peerID, message)
}

func (t *tracker) addCommit(cm *CommitMessage) {
Expand All @@ -76,10 +71,11 @@ func (t *tracker) addCommit(cm *CommitMessage) {
t.commitMessages[cm.Vote.Hash] = cm
}

func (t *tracker) addCatchUpResponse(cr *CatchUpResponse) {
func (t *tracker) addCatchUpResponse(_ *CatchUpResponse) {
t.catchUpResponseMessageMutex.Lock()
defer t.catchUpResponseMessageMutex.Unlock()
t.catchUpResponseMessages[cr.Round] = cr
// uncomment when usage is setup properly, see #1531
// t.catchUpResponseMessages[cr.Round] = cr
}

func (t *tracker) handleBlocks() {
Expand Down Expand Up @@ -108,18 +104,18 @@ func (t *tracker) handleBlock(b *types.Block) {
defer t.mapLock.Unlock()

h := b.Header.Hash()
if vms, has := t.voteMessages[h]; has {
for _, v := range vms {
// handleMessage would never error for vote message
_, err := t.handler.handleMessage(v.from, v.msg)
if err != nil {
logger.Warnf("failed to handle vote message %v: %s", v, err)
}
vms := t.votes.messages(h)
for _, v := range vms {
// handleMessage would never error for vote message
_, err := t.handler.handleMessage(v.from, v.msg)
if err != nil {
logger.Warnf("failed to handle vote message %v: %s", v, err)
}

delete(t.voteMessages, h)
}

// delete block hash that may or may not be in the tracker.
t.votes.delete(h)

if cm, has := t.commitMessages[h]; has {
_, err := t.handler.handleMessage("", cm)
if err != nil {
Expand All @@ -134,17 +130,17 @@ func (t *tracker) handleTick() {
t.mapLock.Lock()
defer t.mapLock.Unlock()

for _, vms := range t.voteMessages {
for _, v := range vms {
for _, networkVoteMessage := range t.votes.networkVoteMessages() {
peerID := networkVoteMessage.from
message := networkVoteMessage.msg
_, err := t.handler.handleMessage(peerID, message)
if err != nil {
// handleMessage would never error for vote message
_, err := t.handler.handleMessage(v.from, v.msg)
if err != nil {
logger.Debugf("failed to handle vote message %v: %s", v, err)
}
logger.Debugf("failed to handle vote message %v from peer id %s: %s", message, peerID, err)
}

if v.msg.Round < t.handler.grandpa.state.round && v.msg.SetID == t.handler.grandpa.state.setID {
delete(t.voteMessages, v.msg.Message.BlockHash)
}
if message.Round < t.handler.grandpa.state.round && message.SetID == t.handler.grandpa.state.setID {
t.votes.delete(message.Message.BlockHash)
}
}

Expand Down
71 changes: 37 additions & 34 deletions lib/grandpa/message_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ import (
"github.com/stretchr/testify/require"
)

// getMessageFromVotesTracker returns the vote message
// from the votes tracker for the given block hash and authority ID.
func getMessageFromVotesTracker(votes votesTracker,
blockHash common.Hash, authorityID ed25519.PublicKeyBytes) (
message *VoteMessage) {
authorityIDToElement, has := votes.mapping[blockHash]
if !has {
return nil
}

element, ok := authorityIDToElement[authorityID]
if !ok {
return nil
}

return element.Value.(networkVoteMessage).msg
}

func TestMessageTracker_ValidateMessage(t *testing.T) {
kr, err := keystore.NewEd25519Keyring()
require.NoError(t, err)
Expand All @@ -33,13 +51,11 @@ func TestMessageTracker_ValidateMessage(t *testing.T) {
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

expected := &networkVoteMessage{
msg: msg,
}

_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
require.Equal(t, expected, gs.tracker.voteMessages[fake.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()])
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, fake.Hash(), authorityID)
require.Equal(t, msg, voteMessage)
}

func TestMessageTracker_SendMessage(t *testing.T) {
Expand Down Expand Up @@ -72,13 +88,11 @@ func TestMessageTracker_SendMessage(t *testing.T) {
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

expected := &networkVoteMessage{
msg: msg,
}

_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
require.Equal(t, expected, gs.tracker.voteMessages[next.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()])
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Header: *next,
Expand Down Expand Up @@ -126,13 +140,11 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

expected := &networkVoteMessage{
msg: msg,
}

_, err = gs.validateVoteMessage("", msg)
require.Equal(t, ErrBlockDoesNotExist, err)
require.Equal(t, expected, gs.tracker.voteMessages[next.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()])
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Header: *next,
Expand All @@ -147,7 +159,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
}
pv, has := gs.prevotes.Load(kr.Alice().Public().(*ed25519.PublicKey).AsBytes())
require.True(t, has)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.voteMessages)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.votes)
}

func TestMessageTracker_MapInsideMap(t *testing.T) {
Expand All @@ -163,24 +175,19 @@ func TestMessageTracker_MapInsideMap(t *testing.T) {
}

hash := header.Hash()
_, ok := gs.tracker.voteMessages[hash]
require.False(t, ok)
messages := gs.tracker.votes.messages(hash)
require.Empty(t, messages)

gs.keypair = kr.Alice().(*ed25519.Keypair)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
_, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(header), prevote)
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

gs.tracker.addVote(&networkVoteMessage{
msg: msg,
})

voteMsgs, ok := gs.tracker.voteMessages[hash]
require.True(t, ok)
gs.tracker.addVote("", msg)

_, ok = voteMsgs[authorityID]
require.True(t, ok)
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, hash, authorityID)
require.NotEmpty(t, voteMessage)
}

func TestMessageTracker_handleTick(t *testing.T) {
Expand All @@ -197,9 +204,7 @@ func TestMessageTracker_handleTick(t *testing.T) {
BlockHash: testHash,
},
}
gs.tracker.addVote(&networkVoteMessage{
msg: msg,
})
gs.tracker.addVote("", msg)

gs.tracker.handleTick()

Expand All @@ -212,7 +217,7 @@ func TestMessageTracker_handleTick(t *testing.T) {
}

// shouldn't be deleted as round in message >= grandpa round
require.Equal(t, 1, len(gs.tracker.voteMessages[testHash]))
require.Len(t, gs.tracker.votes.messages(testHash), 1)

gs.state.round = 1
msg = &VoteMessage{
Expand All @@ -221,9 +226,7 @@ func TestMessageTracker_handleTick(t *testing.T) {
BlockHash: testHash,
},
}
gs.tracker.addVote(&networkVoteMessage{
msg: msg,
})
gs.tracker.addVote("", msg)

gs.tracker.handleTick()

Expand All @@ -235,5 +238,5 @@ func TestMessageTracker_handleTick(t *testing.T) {
}

// should be deleted as round in message < grandpa round
require.Empty(t, len(gs.tracker.voteMessages[testHash]))
require.Empty(t, gs.tracker.votes.messages(testHash))
}
76 changes: 46 additions & 30 deletions lib/grandpa/vote_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,51 +126,70 @@ func (s *Service) validateVoteMessage(from peer.ID, m *VoteMessage) (*Vote, erro
// check for message signature
pk, err := ed25519.NewPublicKey(m.Message.AuthorityID[:])
if err != nil {
// TODO Affect peer reputation
// https://github.com/ChainSafe/gossamer/issues/2505
return nil, err
}

err = validateMessageSignature(pk, m)
if err != nil {
// TODO Affect peer reputation
// https://github.com/ChainSafe/gossamer/issues/2505
return nil, err
}

if m.SetID != s.state.setID {
return nil, ErrSetIDMismatch
}

// check that vote is for current round
if m.Round != s.state.round {
if m.Round < s.state.round {
// peer doesn't know round was finalised, send out another commit message
header, err := s.blockState.GetFinalisedHeader(m.Round, m.SetID)
if err != nil {
return nil, err
}
const maxRoundsLag = 1
minRoundAccepted := s.state.round - maxRoundsLag
if minRoundAccepted > s.state.round {
// we overflowed below 0 so set the minimum to 0.
minRoundAccepted = 0
}

cm, err := s.newCommitMessage(header, m.Round)
if err != nil {
return nil, err
}
const maxRoundsAhead = 1
maxRoundAccepted := s.state.round + maxRoundsAhead

// send finalised block from previous round to network
msg, err := cm.ToConsensusMessage()
if err != nil {
return nil, err
}
if m.Round < minRoundAccepted || m.Round > maxRoundAccepted {
// Discard message
// TODO: affect peer reputation, this is shameful impolite behaviour
// https://github.com/ChainSafe/gossamer/issues/2505
return nil, nil //nolint:nilnil
}

if err = s.network.SendMessage(from, msg); err != nil {
logger.Warnf("failed to send CommitMessage: %s", err)
}
} else {
// round is higher than ours, perhaps we are behind. store vote in tracker for now
s.tracker.addVote(&networkVoteMessage{
from: from,
msg: m,
})
if m.Round < s.state.round {
// message round is lagging by 1
// peer doesn't know round was finalised, send out another commit message
header, err := s.blockState.GetFinalisedHeader(m.Round, m.SetID)
if err != nil {
return nil, err
}

cm, err := s.newCommitMessage(header, m.Round)
if err != nil {
return nil, err
}

// send finalised block from previous round to network
msg, err := cm.ToConsensusMessage()
if err != nil {
return nil, err
}

if err = s.network.SendMessage(from, msg); err != nil {
logger.Warnf("failed to send CommitMessage: %s", err)
}

// TODO: get justification if your round is lower, or just do catch-up? (#1815)
return nil, errRoundMismatch(m.Round, s.state.round)
} else if m.Round > s.state.round {
// Message round is higher by 1 than the round of our state,
// we may be lagging behind, so store the message in the tracker
// for processing later in the coming few milliseconds.
s.tracker.addVote(from, m)
return nil, errRoundMismatch(m.Round, s.state.round)
}

// check for equivocation ie. multiple votes within one subround
Expand All @@ -192,10 +211,7 @@ func (s *Service) validateVoteMessage(from peer.ID, m *VoteMessage) (*Vote, erro
errors.Is(err, blocktree.ErrDescendantNotFound) ||
errors.Is(err, blocktree.ErrEndNodeNotFound) ||
errors.Is(err, blocktree.ErrStartNodeNotFound) {
s.tracker.addVote(&networkVoteMessage{
from: from,
msg: m,
})
s.tracker.addVote(from, m)
}
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit d2ee47e

Please sign in to comment.