diff --git a/dot/digest/digest_test.go b/dot/digest/digest_test.go index c0f6839a3f..c5baa65c3a 100644 --- a/dot/digest/digest_test.go +++ b/dot/digest/digest_test.go @@ -138,14 +138,14 @@ func TestHandler_GrandpaScheduledChange(t *testing.T) { require.NoError(t, err) headers := addTestBlocksToState(t, 2, handler.blockState) - for _, h := range headers { - handler.blockState.(*state.BlockState).SetFinalizedHash(h.Hash(), 0, 0) + for i, h := range headers { + handler.blockState.(*state.BlockState).SetFinalizedHash(h.Hash(), uint64(i), 0) } // authorities should change on start of block 3 from start headers = addTestBlocksToState(t, 1, handler.blockState) for _, h := range headers { - handler.blockState.(*state.BlockState).SetFinalizedHash(h.Hash(), 0, 0) + handler.blockState.(*state.BlockState).SetFinalizedHash(h.Hash(), 3, 0) } time.Sleep(time.Millisecond * 500) @@ -231,8 +231,8 @@ func TestHandler_GrandpaPauseAndResume(t *testing.T) { require.Equal(t, big.NewInt(int64(p.Delay)), nextPause) headers := addTestBlocksToState(t, 3, handler.blockState) - for _, h := range headers { - handler.blockState.(*state.BlockState).SetFinalizedHash(h.Hash(), 0, 0) + for i, h := range headers { + handler.blockState.(*state.BlockState).SetFinalizedHash(h.Hash(), uint64(i), 0) } time.Sleep(time.Millisecond * 100) diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 3b575e365d..91f0a21245 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -281,7 +281,12 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc return } - if !added { + // TODO: ensure grandpa stores *all* previously received votes and discards them + // only when they are for already finalised rounds; currently this causes issues + // because a vote might be received slightly too early, causing a round mismatch err, + // causing grandpa to discard the vote. + _, isConsensusMsg := msg.(*ConsensusMessage) + if !added && !isConsensusMsg { return } } diff --git a/dot/network/sync.go b/dot/network/sync.go index 36de6e405d..dd4cdccfd7 100644 --- a/dot/network/sync.go +++ b/dot/network/sync.go @@ -36,6 +36,16 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ) +// SendBlockReqestByHash sends a block request to the network with the given block hash +func (s *Service) SendBlockReqestByHash(hash common.Hash) { + req := createBlockRequestWithHash(hash, blockRequestSize) + s.syncQueue.requestDataByHash.Delete(hash) + s.syncQueue.trySync(&syncRequest{ + req: req, + to: "", + }) +} + // handleSyncStream handles streams with the /sync/2 protocol ID func (s *Service) handleSyncStream(stream libp2pnetwork.Stream) { if stream == nil { @@ -537,7 +547,11 @@ func (q *syncQueue) pushResponse(resp *BlockResponseMessage, pid peer.ID) error } q.responses = sortResponses(q.responses) - logger.Debug("pushed block data to queue", "start", start, "end", end, "queue", q.stringifyResponseQueue()) + logger.Debug("pushed block data to queue", "start", start, "end", end, + "start hash", q.responses[0].Hash, + "end hash", q.responses[len(q.responses)-1].Hash, + "queue", q.stringifyResponseQueue(), + ) return nil } @@ -611,9 +625,10 @@ func (q *syncQueue) trySync(req *syncRequest) { logger.Trace("trying peers in prioritised order...") syncPeers := q.getSortedPeers() - for _, peer := range syncPeers { + for i, peer := range syncPeers { // if peer doesn't respond multiple times, then ignore them TODO: determine best values for this - if peer.score <= badPeerThreshold { + // TODO: if we only have a few peers, should we do this check at all? + if peer.score <= badPeerThreshold && i > q.s.cfg.MinPeers { break } @@ -647,9 +662,6 @@ func (q *syncQueue) trySync(req *syncRequest) { q.justificationRequestData.Store(startingBlockHash, reqdata) } - - req.to = "" - q.requestCh <- req } func (q *syncQueue) syncWithPeer(peer peer.ID, req *BlockRequestMessage) (*BlockResponseMessage, error) { @@ -737,7 +749,7 @@ func (q *syncQueue) handleBlockData(data []*types.BlockData) { end := data[len(data)-1].Number().Int64() if end <= finalised.Number.Int64() { - logger.Debug("ignoring block data that is below our head", "got", end, "head", finalised.Number.Int64()) + logger.Debug("ignoring block data that is below our finalised head", "got", end, "head", finalised.Number.Int64()) q.pushRequest(uint64(end+1), blockRequestBufferSize, "") return } @@ -844,21 +856,16 @@ func (q *syncQueue) handleBlockAnnounce(msg *BlockAnnounceMessage, from peer.ID) return } - if header.Number.Int64() <= q.goal { - return + if header.Number.Int64() > q.goal { + q.goal = header.Number.Int64() } - q.goal = header.Number.Int64() - - bestNum, err := q.s.blockState.BestBlockNumber() - if err != nil { - logger.Error("failed to get best block number", "error", err) - return + req := createBlockRequestWithHash(header.Hash(), blockRequestSize) + q.requestDataByHash.Delete(req) + q.requestCh <- &syncRequest{ + req: req, + to: from, } - - // TODO: if we're at the head, this should request by hash instead of number, since there will - // certainly be blocks with the same number. - q.pushRequest(uint64(bestNum.Int64()+1), blockRequestBufferSize, from) } func createBlockRequest(startInt int64, size uint32) *BlockRequestMessage { @@ -875,7 +882,7 @@ func createBlockRequest(startInt int64, size uint32) *BlockRequestMessage { RequestedData: RequestedDataHeader + RequestedDataBody + RequestedDataJustification, StartingBlock: start, EndBlockHash: optional.NewHash(false, common.Hash{}), - Direction: 0, // ascending + Direction: 0, // TODO: define this somewhere Max: max, } @@ -896,7 +903,7 @@ func createBlockRequestWithHash(startHash common.Hash, size uint32) *BlockReques RequestedData: RequestedDataHeader + RequestedDataBody + RequestedDataJustification, StartingBlock: start, EndBlockHash: optional.NewHash(false, common.Hash{}), - Direction: 0, // ascending + Direction: 0, // TODO: define this somewhere Max: max, } diff --git a/dot/network/sync_test.go b/dot/network/sync_test.go index 4529db0e7e..d519ab3e69 100644 --- a/dot/network/sync_test.go +++ b/dot/network/sync_test.go @@ -271,11 +271,12 @@ func TestSyncQueue_HandleBlockAnnounce(t *testing.T) { require.True(t, ok) require.Equal(t, 1, score.(int)) require.Equal(t, testBlockAnnounceMessage.Number.Int64(), q.goal) - require.Equal(t, 6, len(q.requestCh)) + require.Equal(t, 1, len(q.requestCh)) - head, err := q.s.blockState.BestBlockNumber() - require.NoError(t, err) - expected := createBlockRequest(head.Int64(), blockRequestSize) + header := &types.Header{ + Number: testBlockAnnounceMessage.Number, + } + expected := createBlockRequestWithHash(header.Hash(), blockRequestSize) req := <-q.requestCh require.Equal(t, &syncRequest{req: expected, to: testPeerID}, req) } diff --git a/dot/state/block.go b/dot/state/block.go index 195239f49b..6177c46bd7 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -432,8 +432,9 @@ func (bs *BlockState) SetFinalizedHash(hash common.Hash, round, setID uint64) er } } - go bs.notifyFinalized(hash, round, setID) if round > 0 { + go bs.notifyFinalized(hash, round, setID) + err := bs.SetRound(round) if err != nil { return err @@ -452,7 +453,7 @@ func (bs *BlockState) SetFinalizedHash(hash common.Hash, round, setID uint64) er return err } - logger.Trace("pruned block", "hash", rem) + logger.Trace("pruned block", "hash", rem, "number", header.Number) bs.pruneKeyCh <- header } diff --git a/dot/state/block_notify_test.go b/dot/state/block_notify_test.go index ace6e1067f..b0bd5b26c3 100644 --- a/dot/state/block_notify_test.go +++ b/dot/state/block_notify_test.go @@ -61,7 +61,7 @@ func TestFinalizedChannel(t *testing.T) { chain, _ := AddBlocksToState(t, bs, 3) for _, b := range chain { - bs.SetFinalizedHash(b.Hash(), 0, 0) + bs.SetFinalizedHash(b.Hash(), 1, 0) } for i := 0; i < 1; i++ { @@ -146,7 +146,7 @@ func TestFinalizedChannel_Multi(t *testing.T) { } time.Sleep(time.Millisecond * 10) - bs.SetFinalizedHash(chain[0].Hash(), 0, 0) + bs.SetFinalizedHash(chain[0].Hash(), 1, 0) wg.Wait() for _, id := range ids { diff --git a/dot/sync/message.go b/dot/sync/message.go index 12d338712c..9031301a5d 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -101,16 +101,16 @@ func (s *Service) CreateBlockResponse(blockRequest *network.BlockRequestMessage) responseData := []*types.BlockData{} switch blockRequest.Direction { - case 0: // ascending (ie child to parent) - for i := endHeader.Number.Int64(); i >= startHeader.Number.Int64(); i-- { + case 0: // ascending (ie parent to child) + for i := startHeader.Number.Int64(); i <= endHeader.Number.Int64(); i++ { blockData, err := s.getBlockData(big.NewInt(i), blockRequest.RequestedData) if err != nil { return nil, err } responseData = append(responseData, blockData) } - case 1: // descending (ie parent to child) - for i := startHeader.Number.Int64(); i <= endHeader.Number.Int64(); i++ { + case 1: // descending (ie child to parent) + for i := endHeader.Number.Int64(); i >= startHeader.Number.Int64(); i-- { blockData, err := s.getBlockData(big.NewInt(i), blockRequest.RequestedData) if err != nil { return nil, err diff --git a/dot/sync/message_test.go b/dot/sync/message_test.go index e0d6cba4da..0d38a43223 100644 --- a/dot/sync/message_test.go +++ b/dot/sync/message_test.go @@ -48,7 +48,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { RequestedData: 3, StartingBlock: start, EndBlockHash: optional.NewHash(false, common.Hash{}), - Direction: 1, + Direction: 0, Max: optional.NewUint32(false, 0), } @@ -62,7 +62,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { RequestedData: 3, StartingBlock: start, EndBlockHash: optional.NewHash(false, common.Hash{}), - Direction: 1, + Direction: 0, Max: optional.NewUint32(true, maxResponseSize+100), } @@ -87,7 +87,7 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { RequestedData: 3, StartingBlock: start, EndBlockHash: optional.NewHash(false, common.Hash{}), - Direction: 1, + Direction: 0, Max: optional.NewUint32(false, 0), } @@ -98,7 +98,7 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { require.Equal(t, big.NewInt(128), resp.BlockData[127].Number()) } -func TestService_CreateBlockResponse_Ascending(t *testing.T) { +func TestService_CreateBlockResponse_Descending(t *testing.T) { s := NewTestSyncer(t, false) addTestBlocksToState(t, int(maxResponseSize), s.blockState) @@ -112,7 +112,7 @@ func TestService_CreateBlockResponse_Ascending(t *testing.T) { RequestedData: 3, StartingBlock: start, EndBlockHash: optional.NewHash(false, common.Hash{}), - Direction: 0, + Direction: 1, Max: optional.NewUint32(false, 0), } @@ -169,7 +169,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 3, StartingBlock: start, EndBlockHash: optional.NewHash(true, endHash), - Direction: 1, + Direction: 0, Max: optional.NewUint32(false, 0), }, expectedMsgValue: &network.BlockResponseMessage{ @@ -188,7 +188,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 1, StartingBlock: start, EndBlockHash: optional.NewHash(true, endHash), - Direction: 1, + Direction: 0, Max: optional.NewUint32(false, 0), }, expectedMsgValue: &network.BlockResponseMessage{ @@ -207,7 +207,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 4, StartingBlock: start, EndBlockHash: optional.NewHash(true, endHash), - Direction: 1, + Direction: 0, Max: optional.NewUint32(false, 0), }, expectedMsgValue: &network.BlockResponseMessage{ @@ -227,7 +227,7 @@ func TestService_CreateBlockResponse(t *testing.T) { RequestedData: 8, StartingBlock: start, EndBlockHash: optional.NewHash(true, endHash), - Direction: 1, + Direction: 0, Max: optional.NewUint32(false, 0), }, expectedMsgValue: &network.BlockResponseMessage{ diff --git a/go.sum b/go.sum index 69581e72c2..dc3f2bccbb 100644 --- a/go.sum +++ b/go.sum @@ -273,6 +273,7 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= diff --git a/lib/common/hash.go b/lib/common/hash.go index 23edcfb8e7..36f7bd41a3 100644 --- a/lib/common/hash.go +++ b/lib/common/hash.go @@ -35,6 +35,9 @@ const ( // Hash used to store a blake2b hash type Hash [32]byte +// EmptyHash is an empty [32]byte{} +var EmptyHash = Hash{} + // NewHash casts a byte array to a Hash // if the input is longer than 32 bytes, it takes the first 32 bytes func NewHash(in []byte) (res Hash) { diff --git a/lib/grandpa/errors.go b/lib/grandpa/errors.go index 1d51970d0e..a6386bd531 100644 --- a/lib/grandpa/errors.go +++ b/lib/grandpa/errors.go @@ -100,4 +100,6 @@ var ( // ErrAuthorityNotInSet is returned when a precommit within a justification is signed by a key not in the authority set ErrAuthorityNotInSet = errors.New("authority is not in set") + + errVoteExists = errors.New("already have vote") ) diff --git a/lib/grandpa/grandpa.go b/lib/grandpa/grandpa.go index 60f5d08611..eb48d4d845 100644 --- a/lib/grandpa/grandpa.go +++ b/lib/grandpa/grandpa.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "errors" + "fmt" "math/big" "os" "sync" @@ -39,7 +40,7 @@ const ( ) var ( - interval = time.Second // TODO: make this configurable; currently 1s is same as substrate; total round length is then 2s + interval = time.Second // TODO: make this configurable; currently 1s is same as substrate; total round length is interval * 2 logger = log.New("pkg", "grandpa") ) @@ -62,20 +63,18 @@ type Service struct { network Network // current state information - state *State // current state - prevotes map[ed25519.PublicKeyBytes]*Vote // pre-votes for the current round - precommits map[ed25519.PublicKeyBytes]*Vote // pre-commits for the current round - pvJustifications map[common.Hash][]*SignedPrecommit // pre-vote justifications for the current round - pcJustifications map[common.Hash][]*SignedPrecommit // pre-commit justifications for the current round - pvEquivocations map[ed25519.PublicKeyBytes][]*Vote // equivocatory votes for current pre-vote stage - pcEquivocations map[ed25519.PublicKeyBytes][]*Vote // equivocatory votes for current pre-commit stage - tracker *tracker // tracker of vote messages we may need in the future - head *types.Header // most recently finalised block + state *State // current state + prevotes *sync.Map // map[ed25519.PublicKeyBytes]*SignedVote // pre-votes for the current round + precommits *sync.Map // map[ed25519.PublicKeyBytes]*SignedVote // pre-commits for the current round + pvEquivocations map[ed25519.PublicKeyBytes][]*SignedVote // equivocatory votes for current pre-vote stage + pcEquivocations map[ed25519.PublicKeyBytes][]*SignedVote // equivocatory votes for current pre-commit stage + tracker *tracker // tracker of vote messages we may need in the future + head *types.Header // most recently finalised block // historical information - preVotedBlock map[uint64]*Vote // map of round number -> pre-voted block - bestFinalCandidate map[uint64]*Vote // map of round number -> best final candidate - justification map[uint64][]*SignedPrecommit // map of round number -> precommit round justification + preVotedBlock map[uint64]*Vote // map of round number -> pre-voted block + bestFinalCandidate map[uint64]*Vote // map of round number -> best final candidate + justification map[uint64][]*SignedVote // map of round number -> precommit round justification // channels for communication with other services in chan GrandpaMessage // only used to receive *VoteMessage @@ -156,15 +155,13 @@ func NewService(cfg *Config) (*Service, error) { digestHandler: cfg.DigestHandler, keypair: cfg.Keypair, authority: cfg.Authority, - prevotes: make(map[ed25519.PublicKeyBytes]*Vote), - precommits: make(map[ed25519.PublicKeyBytes]*Vote), - pvJustifications: make(map[common.Hash][]*SignedPrecommit), - pcJustifications: make(map[common.Hash][]*SignedPrecommit), - pvEquivocations: make(map[ed25519.PublicKeyBytes][]*Vote), - pcEquivocations: make(map[ed25519.PublicKeyBytes][]*Vote), + prevotes: new(sync.Map), + precommits: new(sync.Map), + pvEquivocations: make(map[ed25519.PublicKeyBytes][]*SignedVote), + pcEquivocations: make(map[ed25519.PublicKeyBytes][]*SignedVote), preVotedBlock: make(map[uint64]*Vote), bestFinalCandidate: make(map[uint64]*Vote), - justification: make(map[uint64][]*SignedPrecommit), + justification: make(map[uint64][]*SignedVote), head: head, in: make(chan GrandpaMessage, 128), resumed: make(chan struct{}), @@ -271,15 +268,18 @@ func (s *Service) updateAuthorities() error { func (s *Service) publicKeyBytes() ed25519.PublicKeyBytes { return s.keypair.Public().(*ed25519.PublicKey).AsBytes() } - -// initiate initates a GRANDPA round -func (s *Service) initiate() error { +func (s *Service) initiateRound() error { // if there is an authority change, execute it err := s.updateAuthorities() if err != nil { return err } + s.head, err = s.blockState.GetFinalizedHeader(s.state.round, s.state.setID) + if err != nil { + return err + } + if s.state.round == 0 { s.chanLock.Lock() s.mapLock.Lock() @@ -292,18 +292,16 @@ func (s *Service) initiate() error { // make sure no votes can be validated while we are incrementing rounds s.roundLock.Lock() s.state.round++ - logger.Trace("incrementing grandpa round", "next round", s.state.round) - + logger.Debug("incrementing grandpa round", "next round", s.state.round) if s.tracker != nil { s.tracker.stop() } - s.prevotes = make(map[ed25519.PublicKeyBytes]*Vote) - s.precommits = make(map[ed25519.PublicKeyBytes]*Vote) - s.pcJustifications = make(map[common.Hash][]*SignedPrecommit) - s.pvEquivocations = make(map[ed25519.PublicKeyBytes][]*Vote) - s.pcEquivocations = make(map[ed25519.PublicKeyBytes][]*Vote) - s.justification = make(map[uint64][]*SignedPrecommit) + s.prevotes = new(sync.Map) + s.precommits = new(sync.Map) + s.pvEquivocations = make(map[ed25519.PublicKeyBytes][]*SignedVote) + s.pcEquivocations = make(map[ed25519.PublicKeyBytes][]*SignedVote) + s.justification = make(map[uint64][]*SignedVote) s.tracker, err = newTracker(s.blockState, s.in) if err != nil { return err @@ -312,22 +310,31 @@ func (s *Service) initiate() error { logger.Trace("started message tracker") s.roundLock.Unlock() - // don't begin grandpa until we are at block 1 - h, err := s.blockState.BestBlockHeader() + best, err := s.blockState.BestBlockHeader() if err != nil { return err } - if h != nil && h.Number.Int64() == 0 { - err = s.waitForFirstBlock() + if best.Number.Int64() > 0 { + return nil + } + + // don't begin grandpa until we are at block 1 + return s.waitForFirstBlock() +} + +// initiate initates the grandpa service to begin voting in sequential rounds +func (s *Service) initiate() error { + for { + err := s.initiateRound() if err != nil { + logger.Warn("failed to initiate round", "round", s.state.round, "error", err) return err } - } - for { err = s.playGrandpaRound() if err == ErrServicePaused { + logger.Info("service paused") // wait for service to un-pause <-s.resumed err = s.initiate() @@ -338,12 +345,7 @@ func (s *Service) initiate() error { } if s.ctx.Err() != nil { - return nil - } - - err = s.initiate() - if err != nil { - return err + return errors.New("context cancelled") } } } @@ -378,40 +380,64 @@ func (s *Service) waitForFirstBlock() error { return nil } -// playGrandpaRound executes a round of GRANDPA -// at the end of this round, a block will be finalised. -func (s *Service) playGrandpaRound() error { - logger.Debug("starting round", "round", s.state.round, "setID", s.state.setID) - - // save start time - start := time.Now() - +func (s *Service) handleIsPrimary() (bool, error) { // derive primary primary := s.derivePrimary() // if primary, broadcast the best final candidate from the previous round - if bytes.Equal(primary.Key.Encode(), s.keypair.Public().Encode()) { + // otherwise, do nothing + if !bytes.Equal(primary.Key.Encode(), s.keypair.Public().Encode()) { + return false, nil + } + + if s.head.Number.Int64() > 0 { + // send finalised block from previous round to network msg, err := s.newCommitMessage(s.head, s.state.round-1).ToConsensusMessage() if err != nil { - logger.Error("failed to encode finalisation message", "error", err) - } else { - s.network.SendMessage(msg) + return false, fmt.Errorf("failed to encode finalisation message: %w", err) } - primProposal, err := s.createVoteMessage(&Vote{ - hash: s.head.Hash(), - number: uint32(s.head.Number.Int64()), - }, primaryProposal, s.keypair) - if err != nil { - logger.Error("failed to create primary proposal message", "error", err) - } else { - msg, err = primProposal.ToConsensusMessage() - if err != nil { - logger.Error("failed to encode finalisation message", "error", err) - } else { - s.network.SendMessage(msg) - } - } + s.network.SendMessage(msg) + } + + best, err := s.blockState.BestBlockHeader() + if err != nil { + return false, err + } + + pv := &Vote{ + hash: best.Hash(), + number: uint32(best.Number.Int64()), + } + + // send primary prevote message to network + spv, primProposal, err := s.createSignedVoteAndVoteMessage(pv, primaryProposal) + if err != nil { + return false, fmt.Errorf("failed to create primary proposal message: %w", err) + } + + s.prevotes.Store(s.publicKeyBytes(), spv) + + msg, err := primProposal.ToConsensusMessage() + if err != nil { + return false, fmt.Errorf("failed to encode finalisation message: %w", err) + } + + s.network.SendMessage(msg) + return true, nil +} + +// playGrandpaRound executes a round of GRANDPA +// at the end of this round, a block will be finalised. +func (s *Service) playGrandpaRound() error { + logger.Debug("starting round", "round", s.state.round, "setID", s.state.setID) + + // save start time + start := time.Now() + + isPrimary, err := s.handleIsPrimary() + if err != nil { + return err } logger.Debug("receiving pre-vote messages...") @@ -445,33 +471,20 @@ func (s *Service) playGrandpaRound() error { return err } - s.mapLock.Lock() - s.prevotes[s.publicKeyBytes()] = pv - logger.Debug("sending pre-vote message...", "vote", pv, "prevotes", s.prevotes) - s.mapLock.Unlock() - - finalised := false - - // continue to send prevote messages until round is done - go func(finalised *bool) { - for { - if s.paused.Load().(bool) { - return - } + spv, vm, err := s.createSignedVoteAndVoteMessage(pv, prevote) + if err != nil { + return err + } - if *finalised { - return - } + if !isPrimary { + s.prevotes.Store(s.publicKeyBytes(), spv) + } + logger.Debug("sending pre-vote message...", "vote", pv) - err = s.sendMessage(pv, prevote) - if err != nil { - logger.Error("could not send prevote message", "error", err) - } + roundComplete := make(chan struct{}) - time.Sleep(time.Second * 5) - logger.Trace("sent pre-vote message...", "vote", pv, "prevotes", s.prevotes) - } - }(&finalised) + // continue to send prevote messages until round is done + go s.sendVoteMessage(prevote, vm, roundComplete) logger.Debug("receiving pre-commit messages...") @@ -488,7 +501,7 @@ func (s *Service) playGrandpaRound() error { return false }) - time.Sleep(interval * 2) + time.Sleep(interval) if s.paused.Load().(bool) { return ErrServicePaused @@ -500,31 +513,16 @@ func (s *Service) playGrandpaRound() error { return err } - s.mapLock.Lock() - s.precommits[s.publicKeyBytes()] = pc - logger.Debug("sending pre-commit message...", "vote", pc, "precommits", s.precommits) - s.mapLock.Unlock() - - // continue to send precommit messages until round is done - go func(finalised *bool) { - for { - if s.paused.Load().(bool) { - return - } - - if *finalised { - return - } + spc, pcm, err := s.createSignedVoteAndVoteMessage(pc, precommit) + if err != nil { + return err + } - err = s.sendMessage(pc, precommit) - if err != nil { - logger.Error("could not send precommit message", "error", err) - } + s.precommits.Store(s.publicKeyBytes(), spc) + logger.Debug("sending pre-commit message...", "vote", pc) - time.Sleep(time.Second * 5) - logger.Trace("sent pre-commit message...", "vote", pc, "precommits", s.precommits) - } - }(&finalised) + // continue to send precommit messages until round is done + go s.sendVoteMessage(precommit, pcm, roundComplete) go func() { // receive messages until current round is completable and previous round is finalisable @@ -562,42 +560,76 @@ func (s *Service) playGrandpaRound() error { }) }() + time.Sleep(interval) + err = s.attemptToFinalize() if err != nil { - log.Error("failed to finalise", "error", err) + logger.Error("failed to finalise", "error", err) return err } - finalised = true + close(roundComplete) return nil } +func (s *Service) sendVoteMessage(stage subround, msg *VoteMessage, roundComplete <-chan struct{}) { + ticker := time.NewTicker(interval * 4) + defer ticker.Stop() + + for { + if s.paused.Load().(bool) { + return + } + + err := s.sendMessage(msg) + if err != nil { + logger.Warn("could not send message", "stage", stage, "error", err) + } + + logger.Trace("sent vote message", "stage", stage, "vote", msg) + select { + case <-roundComplete: + return + case <-ticker.C: + } + } +} + // attemptToFinalize loops until the round is finalisable func (s *Service) attemptToFinalize() error { - if s.paused.Load().(bool) { - return ErrServicePaused - } + ticker := time.NewTicker(interval / 100) - if s.ctx.Err() != nil { - return nil - } + for { + select { + case <-s.ctx.Done(): + return errors.New("context cancelled") + case <-ticker.C: + } - has, _ := s.blockState.HasFinalizedBlock(s.state.round, s.state.setID) - if has { - return nil // a block was finalised, seems like we missed some messages - } + if s.paused.Load().(bool) { + return ErrServicePaused + } - bfc, err := s.getBestFinalCandidate() - if err != nil { - return err - } + has, _ := s.blockState.HasFinalizedBlock(s.state.round, s.state.setID) + if has { + logger.Debug("block was finalised!", "round", s.state.round) + return nil // a block was finalised, seems like we missed some messages + } - pc, err := s.getTotalVotesForBlock(bfc.hash, precommit) - if err != nil { - return err - } + bfc, err := s.getBestFinalCandidate() + if err != nil { + return err + } + + pc, err := s.getTotalVotesForBlock(bfc.hash, precommit) + if err != nil { + return err + } + + if bfc.number < uint32(s.head.Number.Int64()) || pc < s.state.threshold() { + continue + } - if bfc.number >= uint32(s.head.Number.Int64()) && pc >= s.state.threshold() { err = s.finalise() if err != nil { return err @@ -605,19 +637,56 @@ func (s *Service) attemptToFinalize() error { // if we haven't received a finalisation message for this block yet, broadcast a finalisation message votes := s.getDirectVotes(precommit) - logger.Debug("finalised block!!!", "setID", s.state.setID, "round", s.state.round, "hash", s.head.Hash(), - "precommits #", pc, "votes for bfc #", votes[*bfc], "total votes for bfc", pc, "precommits", s.precommits) - msg, err := s.newCommitMessage(s.head, s.state.round).ToConsensusMessage() + logger.Debug("finalised block!!!", + "setID", s.state.setID, + "round", s.state.round, + "hash", s.head.Hash(), + "precommits #", pc, + "direct votes for bfc #", votes[*bfc], + "total votes for bfc", pc, + "precommits", s.precommits, + "justification count", len(s.justification[s.state.round]), + ) + + cm := s.newCommitMessage(s.head, s.state.round) + msg, err := cm.ToConsensusMessage() if err != nil { return err } + logger.Debug("sending CommitMessage", "msg", cm) s.network.SendMessage(msg) return nil } +} + +func (s *Service) loadVote(key ed25519.PublicKeyBytes, stage subround) (*SignedVote, bool) { + var ( + v interface{} + has bool + ) + + switch stage { + case prevote, primaryProposal: + v, has = s.prevotes.Load(key) + case precommit: + v, has = s.precommits.Load(key) + } + + if !has { + return nil, false + } + + return v.(*SignedVote), true +} - time.Sleep(time.Millisecond * 10) - return s.attemptToFinalize() +func (s *Service) deleteVote(key ed25519.PublicKeyBytes, stage subround) { + switch stage { + case prevote, primaryProposal: + s.prevotes.Delete(key) + case precommit: + s.precommits.Delete(key) + } } // determinePreVote determines what block is our pre-voted block for the current round @@ -627,12 +696,9 @@ func (s *Service) determinePreVote() (*Vote, error) { // if we receive a vote message from the primary with a block that's greater than or equal to the current pre-voted block // and greater than the best final candidate from the last round, we choose that. // otherwise, we simply choose the head of our chain. - s.mapLock.Lock() - prm := s.prevotes[s.derivePrimary().PublicKeyBytes()] - s.mapLock.Unlock() - - if prm != nil && prm.number >= uint32(s.head.Number.Int64()) { - vote = prm + prm, has := s.loadVote(s.derivePrimary().PublicKeyBytes(), prevote) + if has && prm.Vote.number >= uint32(s.head.Number.Int64()) { + vote = prm.Vote } else { header, err := s.blockState.BestBlockHeader() if err != nil { @@ -748,19 +814,31 @@ func (s *Service) finalise() error { // set best final candidate s.bestFinalCandidate[s.state.round] = bfc - // set justification - s.justification[s.state.round] = s.pcJustifications[bfc.hash] + // create prevote justification ie. list of all signed prevotes for the bfc + pvs, err := s.createJustification(bfc.hash, prevote) + if err != nil { + return err + } - pvj, err := newJustification(s.state.round, bfc.hash, bfc.number, s.pvJustifications[bfc.hash]).Encode() + // create precommit justification ie. list of all signed precommits for the bfc + pcs, err := s.createJustification(bfc.hash, precommit) if err != nil { return err } - pcj, err := newJustification(s.state.round, bfc.hash, bfc.number, s.pcJustifications[bfc.hash]).Encode() + pvj, err := newJustification(s.state.round, bfc.hash, bfc.number, pvs).Encode() if err != nil { return err } + pcj, err := newJustification(s.state.round, bfc.hash, bfc.number, pcs).Encode() + if err != nil { + return err + } + + // cache justification + s.justification[s.state.round] = pcs + err = s.blockState.SetJustification(bfc.hash, append(pvj, pcj...)) if err != nil { return err @@ -781,6 +859,48 @@ func (s *Service) finalise() error { return s.blockState.SetFinalizedHash(bfc.hash, 0, 0) } +// createJustification collects the signed precommits received for this round and turns them into +// a justification by adding all signed precommits that are for the best finalised candidate or +// a descendent of the bfc +func (s *Service) createJustification(bfc common.Hash, stage subround) ([]*SignedVote, error) { + var ( + spc *sync.Map + err error + just []*SignedVote + ) + + switch stage { + case prevote: + spc = s.prevotes + case precommit: + spc = s.precommits + } + + // TODO: use equivacatory votes to create justification as well + spc.Range(func(_, value interface{}) bool { + pc := value.(*SignedVote) + var isDescendant bool + + isDescendant, err = s.blockState.IsDescendantOf(bfc, pc.Vote.hash) + if err != nil { + return false + } + + if !isDescendant { + return true + } + + just = append(just, pc) + return true + }) + + if err != nil { + return nil, err + } + + return just, nil +} + // derivePrimary returns the primary for the current round func (s *Service) derivePrimary() *Voter { return s.state.voters[s.state.round%uint64(len(s.state.voters))] @@ -858,12 +978,21 @@ func (s *Service) getBestFinalCandidate() (*Vote, error) { // isCompletable returns true if the round is completable, false otherwise func (s *Service) isCompletable() (bool, error) { - votes := s.getVotes(precommit) + // haven't received enough votes, not completable + if uint64(s.lenVotes(precommit)+len(s.pcEquivocations)) < s.state.threshold() { + return false, nil + } + prevoted, err := s.getPreVotedBlock() if err != nil { return false, err } + votes := s.getVotes(precommit) + + // for each block with at least 1 vote that's a descendant of the prevoted block, + // check that (total precommits - total pc equivocations - precommits for that block) >= 2/3 |V| + // ie. there must not be a descendent of the prevotes block that is preferred for _, v := range votes { if prevoted.hash == v.hash { continue @@ -879,13 +1008,12 @@ func (s *Service) isCompletable() (bool, error) { continue } - // if it's a descendant, check if has >=2/3 votes c, err := s.getTotalVotesForBlock(v.hash, precommit) if err != nil { return false, err } - if c > s.state.threshold() { + if uint64(len(votes)-len(s.pcEquivocations))-c < s.state.threshold() { // round isn't completable return false, nil } @@ -1114,19 +1242,18 @@ func (s *Service) getVotesForBlock(hash common.Hash, stage subround) (uint64, er func (s *Service) getDirectVotes(stage subround) map[Vote]uint64 { votes := make(map[Vote]uint64) - var src map[ed25519.PublicKeyBytes]*Vote + var src *sync.Map if stage == prevote { src = s.prevotes } else { src = s.precommits } - s.mapLock.Lock() - defer s.mapLock.Unlock() - - for _, v := range src { - votes[*v]++ - } + src.Range(func(_, value interface{}) bool { + sv := value.(*SignedVote) + votes[*sv.Vote]++ + return true + }) return votes } @@ -1170,3 +1297,22 @@ func (s *Service) findParentWithNumber(v *Vote, n uint32) (*Vote, error) { return NewVoteFromHeader(b), nil } + +func (s *Service) lenVotes(stage subround) int { + var count int + + switch stage { + case prevote, primaryProposal: + s.prevotes.Range(func(_, _ interface{}) bool { + count++ + return true + }) + case precommit: + s.precommits.Range(func(_, _ interface{}) bool { + count++ + return true + }) + } + + return count +} diff --git a/lib/grandpa/grandpa_test.go b/lib/grandpa/grandpa_test.go index 6485fbfc9c..8a00793cbc 100644 --- a/lib/grandpa/grandpa_test.go +++ b/lib/grandpa/grandpa_test.go @@ -148,9 +148,13 @@ func TestGetDirectVotes(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 5 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -178,9 +182,13 @@ func TestGetVotesForBlock_NoDescendantVotes(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 5 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -216,11 +224,17 @@ func TestGetVotesForBlock_DescendantVotes(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 5 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } } @@ -261,11 +275,17 @@ func TestGetPossibleSelectedAncestors_SameAncestor(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 6 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } } @@ -311,11 +331,17 @@ func TestGetPossibleSelectedAncestors_VaryingAncestor(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 6 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } } @@ -369,13 +395,21 @@ func TestGetPossibleSelectedAncestors_VaryingAncestor_MoreBranches(t *testing.T) voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 6 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else if i < 8 { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } else { - gs.prevotes[voter] = voteD + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteD, + }) } } @@ -418,9 +452,13 @@ func TestGetPossibleSelectedBlocks_OneBlock(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -453,11 +491,17 @@ func TestGetPossibleSelectedBlocks_EqualVotes_SameAncestor(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 6 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } } @@ -496,11 +540,17 @@ func TestGetPossibleSelectedBlocks_EqualVotes_VaryingAncestor(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 6 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } } @@ -534,15 +584,22 @@ func TestGetPossibleSelectedBlocks_OneThirdEquivocating(t *testing.T) { voteB, err := NewVoteFromHash(leaves[1], st.Block) require.NoError(t, err) + svA := &SignedVote{ + Vote: voteA, + } + svB := &SignedVote{ + Vote: voteB, + } + for i, k := range kr.Keys { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, svA) } else if i < 6 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, svB) } else { - gs.pvEquivocations[voter] = []*Vote{voteA, voteB} + gs.pvEquivocations[voter] = []*SignedVote{svA, svB} } } @@ -568,21 +625,30 @@ func TestGetPossibleSelectedBlocks_MoreThanOneThirdEquivocating(t *testing.T) { voteC, err := NewVoteFromHash(leaves[2], st.Block) require.NoError(t, err) + svA := &SignedVote{ + Vote: voteA, + } + svB := &SignedVote{ + Vote: voteB, + } + for i, k := range kr.Keys { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 2 { // 2 votes for A - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, svA) } else if i < 4 { // 2 votes for B - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, svB) } else if i < 5 { // 1 vote for C - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } else { // 4 equivocators - gs.pvEquivocations[voter] = []*Vote{voteA, voteB} + gs.pvEquivocations[voter] = []*SignedVote{svA, svB} } } @@ -608,9 +674,13 @@ func TestGetPreVotedBlock_OneBlock(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -643,11 +713,17 @@ func TestGetPreVotedBlock_MultipleCandidates(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 3 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 6 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } } @@ -698,17 +774,29 @@ func TestGetPreVotedBlock_EvenMoreCandidates(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 2 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 4 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else if i < 6 { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } else if i < 7 { - gs.prevotes[voter] = voteD + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteD, + }) } else if i < 8 { - gs.prevotes[voter] = voteE + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteE, + }) } else { - gs.prevotes[voter] = voteF + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteF, + }) } } @@ -741,9 +829,19 @@ func TestIsCompletable(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -791,11 +889,19 @@ func TestGetBestFinalCandidate_OneBlock(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA - gs.precommits[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB - gs.precommits[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -826,11 +932,19 @@ func TestGetBestFinalCandidate_PrecommitAncestor(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA - gs.precommits[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteC, + }) } else { - gs.prevotes[voter] = voteB - gs.precommits[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -858,10 +972,16 @@ func TestGetBestFinalCandidate_NoPrecommit(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB - gs.precommits[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -889,11 +1009,19 @@ func TestGetBestFinalCandidate_PrecommitOnAnotherChain(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA - gs.precommits[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteB, + }) } else { - gs.prevotes[voter] = voteB - gs.precommits[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteA, + }) } } @@ -926,11 +1054,15 @@ func TestDeterminePreVote_WithPrimaryPreVote(t *testing.T) { state.AddBlocksToState(t, st.Block, 1) primary := gs.derivePrimary().PublicKeyBytes() - gs.prevotes[primary] = NewVoteFromHeader(header) + gs.prevotes.Store(primary, &SignedVote{ + Vote: NewVoteFromHeader(header), + }) pv, err := gs.determinePreVote() require.NoError(t, err) - require.Equal(t, gs.prevotes[primary], pv) + p, has := gs.prevotes.Load(primary) + require.True(t, has) + require.Equal(t, pv, p.(*SignedVote).Vote) } func TestDeterminePreVote_WithInvalidPrimaryPreVote(t *testing.T) { @@ -941,7 +1073,9 @@ func TestDeterminePreVote_WithInvalidPrimaryPreVote(t *testing.T) { require.NoError(t, err) primary := gs.derivePrimary().PublicKeyBytes() - gs.prevotes[primary] = NewVoteFromHeader(header) + gs.prevotes.Store(primary, &SignedVote{ + Vote: NewVoteFromHeader(header), + }) state.AddBlocksToState(t, st.Block, 5) gs.head, err = st.Block.BestBlockHeader() @@ -969,11 +1103,19 @@ func TestIsFinalisable_True(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA - gs.precommits[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB - gs.precommits[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -999,11 +1141,19 @@ func TestIsFinalisable_False(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 6 { - gs.prevotes[voter] = voteA - gs.precommits[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteA, + }) } else { - gs.prevotes[voter] = voteB - gs.precommits[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) + gs.precommits.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -1036,9 +1186,13 @@ func TestGetGrandpaGHOST_CommonAncestor(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 4 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 5 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } } @@ -1073,11 +1227,17 @@ func TestGetGrandpaGHOST_MultipleCandidates(t *testing.T) { voter := k.Public().(*ed25519.PublicKey).AsBytes() if i < 1 { - gs.prevotes[voter] = voteA + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteA, + }) } else if i < 2 { - gs.prevotes[voter] = voteB + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteB, + }) } else if i < 3 { - gs.prevotes[voter] = voteC + gs.prevotes.Store(voter, &SignedVote{ + Vote: voteC, + }) } } diff --git a/lib/grandpa/message.go b/lib/grandpa/message.go index 00612a6da9..5e0572f610 100644 --- a/lib/grandpa/message.go +++ b/lib/grandpa/message.go @@ -64,7 +64,7 @@ type SignedMessage struct { // String returns the SignedMessage as a string func (m *SignedMessage) String() string { - return fmt.Sprintf("hash=%s number=%d authorityID=0x%x", m.Hash, m.Number, m.AuthorityID) + return fmt.Sprintf("stage=%s hash=%s number=%d authorityID=%s", m.Stage, m.Hash, m.Number, m.AuthorityID) } // Decode SCALE decodes the data into a SignedMessage @@ -283,7 +283,7 @@ func (s *Service) newCommitMessage(header *types.Header, round uint64) *CommitMe } } -func justificationToCompact(just []*SignedPrecommit) ([]*Vote, []*AuthData) { +func justificationToCompact(just []*SignedVote) ([]*Vote, []*AuthData) { precommits := make([]*Vote, len(just)) authData := make([]*AuthData, len(just)) @@ -330,8 +330,8 @@ func (r *catchUpRequest) ToConsensusMessage() (*ConsensusMessage, error) { type catchUpResponse struct { SetID uint64 Round uint64 - PreVoteJustification []*SignedPrecommit - PreCommitJustification []*SignedPrecommit + PreVoteJustification []*SignedVote + PreCommitJustification []*SignedVote Hash common.Hash Number uint32 } @@ -363,17 +363,17 @@ func (s *Service) newCatchUpResponse(round, setID uint64) (*catchUpResponse, err return nil, err } - d, err := sd.Decode([]*SignedPrecommit{}) + d, err := sd.Decode([]*SignedVote{}) if err != nil { return nil, err } - pvj := d.([]*SignedPrecommit) + pvj := d.([]*SignedVote) - d, err = sd.Decode([]*SignedPrecommit{}) + d, err = sd.Decode([]*SignedVote{}) if err != nil { return nil, err } - pcj := d.([]*SignedPrecommit) + pcj := d.([]*SignedVote) return &catchUpResponse{ SetID: setID, diff --git a/lib/grandpa/message_handler.go b/lib/grandpa/message_handler.go index 6627e80b73..6b990b21e0 100644 --- a/lib/grandpa/message_handler.go +++ b/lib/grandpa/message_handler.go @@ -110,29 +110,6 @@ func (h *MessageHandler) handleNeighbourMessage(from peer.ID, msg *NeighbourMess logger.Debug("got neighbour message", "number", msg.Number, "set id", msg.SetID, "round", msg.Round) h.grandpa.network.SendJustificationRequest(from, msg.Number) - - // don't finalise too close to head, until we add justification request + verification functionality. - // this prevents us from marking the wrong block as final and getting stuck on the wrong chain - if uint32(head.Int64())-4 < msg.Number { - return nil - } - - // TODO: instead of assuming the finalised hash is the one we currently know about, - // request the justification from the network before setting it as finalised. - hash, err := h.grandpa.blockState.GetHashByNumber(big.NewInt(int64(msg.Number))) - if err != nil { - return err - } - - if err = h.grandpa.blockState.SetFinalizedHash(hash, msg.Round, msg.SetID); err != nil { - return err - } - - if err = h.grandpa.blockState.SetFinalizedHash(hash, 0, 0); err != nil { - return err - } - - logger.Info("🔨 finalised block", "number", msg.Number, "hash", hash) return nil } @@ -155,10 +132,12 @@ func (h *MessageHandler) handleCommitMessage(msg *CommitMessage) (*ConsensusMess return nil, err } - // set latest finalised head in db - err = h.blockState.SetFinalizedHash(msg.Vote.hash, 0, 0) - if err != nil { - return nil, err + if msg.Round >= h.grandpa.state.round { + // set latest finalised head in db + err = h.blockState.SetFinalizedHash(msg.Vote.hash, 0, 0) + if err != nil { + return nil, err + } } // check if msg has same setID but is 2 or more rounds ahead of us, if so, return catch-up request to send @@ -174,7 +153,12 @@ func (h *MessageHandler) handleCommitMessage(msg *CommitMessage) (*ConsensusMess } func (h *MessageHandler) handleCatchUpRequest(msg *catchUpRequest) (*ConsensusMessage, error) { + if !h.grandpa.authority { + return nil, nil + } + logger.Debug("received catch up request", "round", msg.Round, "setID", msg.SetID) + if msg.SetID != h.grandpa.state.setID { return nil, ErrSetIDMismatch } @@ -193,6 +177,10 @@ func (h *MessageHandler) handleCatchUpRequest(msg *catchUpRequest) (*ConsensusMe } func (h *MessageHandler) handleCatchUpResponse(msg *catchUpResponse) error { + if !h.grandpa.authority { + return nil + } + logger.Debug("received catch up response", "round", msg.Round, "setID", msg.SetID, "hash", msg.Hash) // if we aren't currently expecting a catch up response, return @@ -267,7 +255,7 @@ func (h *MessageHandler) verifyCommitMessageJustification(fm *CommitMessage) err count := 0 for i, pc := range fm.Precommits { - just := &SignedPrecommit{ + just := &SignedVote{ Vote: pc, Signature: fm.AuthData[i].Signature, AuthorityID: fm.AuthData[i].AuthorityID, @@ -278,7 +266,12 @@ func (h *MessageHandler) verifyCommitMessageJustification(fm *CommitMessage) err continue } - if just.Vote.hash == fm.Vote.hash && just.Vote.number == fm.Vote.number { + isDescendant, err := h.blockState.IsDescendantOf(fm.Vote.hash, just.Vote.hash) + if err != nil { + logger.Warn("verifyCommitMessageJustification", "error", err) + } + + if isDescendant { count++ } } @@ -286,9 +279,11 @@ func (h *MessageHandler) verifyCommitMessageJustification(fm *CommitMessage) err // confirm total # signatures >= grandpa threshold if uint64(count) < h.grandpa.state.threshold() { logger.Debug("minimum votes not met for finalisation message", "votes needed", h.grandpa.state.threshold(), - "votes received", len(fm.Precommits)) + "votes received", count) return ErrMinVotesNotMet } + + logger.Debug("validated commit message", "msg", fm) return nil } @@ -341,7 +336,7 @@ func (h *MessageHandler) verifyPreCommitJustification(msg *catchUpResponse) erro return nil } -func (h *MessageHandler) verifyJustification(just *SignedPrecommit, round, setID uint64, stage subround) error { +func (h *MessageHandler) verifyJustification(just *SignedVote, round, setID uint64, stage subround) error { // verify signature msg, err := scale.Encode(&FullVote{ Stage: stage, diff --git a/lib/grandpa/message_handler_test.go b/lib/grandpa/message_handler_test.go index ef385b4908..0b33342cc2 100644 --- a/lib/grandpa/message_handler_test.go +++ b/lib/grandpa/message_handler_test.go @@ -46,10 +46,10 @@ var testBlock = &types.Block{ var testHash = testHeader.Hash() -func buildTestJustification(t *testing.T, qty int, round, setID uint64, kr *keystore.Ed25519Keyring, subround subround) []*SignedPrecommit { - just := []*SignedPrecommit{} +func buildTestJustification(t *testing.T, qty int, round, setID uint64, kr *keystore.Ed25519Keyring, subround subround) []*SignedVote { + just := []*SignedVote{} for i := 0; i < qty; i++ { - j := &SignedPrecommit{ + j := &SignedVote{ Vote: NewVote(testHash, uint32(round)), Signature: createSignedVoteMsg(t, uint32(round), round, setID, kr.Keys[i%len(kr.Keys)], subround), AuthorityID: kr.Keys[i%len(kr.Keys)].Public().(*ed25519.PublicKey).AsBytes(), @@ -172,7 +172,7 @@ func TestMessageHandler_VoteMessage(t *testing.T) { gs.state.setID = 99 gs.state.round = 77 v.number = 0x7777 - vm, err := gs.createVoteMessage(v, precommit, gs.keypair) + _, vm, err := gs.createSignedVoteAndVoteMessage(v, precommit) require.NoError(t, err) h := NewMessageHandler(gs, st.Block) @@ -192,6 +192,9 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) { gs, st := newTestService(t) h := NewMessageHandler(gs, st.Block) + err := st.Block.AddBlock(testBlock) + require.NoError(t, err) + msg := &NeighbourMessage{ Version: 1, Round: 2, @@ -199,37 +202,26 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) { Number: 1, } - _, err := h.handleMessage("", msg) - require.NoError(t, err) - - block := &types.Block{ - Header: &types.Header{ - Number: big.NewInt(1), - ParentHash: st.Block.GenesisHash(), - Digest: types.Digest{ - types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest(), - }, - }, - Body: &types.Body{0}, - } - - err = st.Block.AddBlock(block) + _, err = h.handleMessage("", msg) require.NoError(t, err) out, err := h.handleMessage("", msg) require.NoError(t, err) require.Nil(t, out) - finalised, err := st.Block.GetFinalizedHash(0, 0) - require.NoError(t, err) - require.Equal(t, block.Header.Hash(), finalised) + // check if request for justification was sent out + expected := &testJustificationRequest{ + to: "", + num: 1, + } + require.Equal(t, expected, gs.network.(*testNetwork).justificationRequest) } func TestMessageHandler_VerifyJustification_InvalidSig(t *testing.T) { gs, st := newTestService(t) gs.state.round = 77 - just := &SignedPrecommit{ + just := &SignedVote{ Vote: testVote, Signature: [64]byte{0x1}, AuthorityID: gs.publicKeyBytes(), @@ -286,7 +278,7 @@ func TestMessageHandler_CommitMessage_NoCatchUpRequest_MinVoteError(t *testing.T func TestMessageHandler_CommitMessage_WithCatchUpRequest(t *testing.T) { gs, st := newTestService(t) - gs.justification[77] = []*SignedPrecommit{ + gs.justification[77] = []*SignedVote{ { Vote: testVote, Signature: testSignature, @@ -347,7 +339,7 @@ func TestMessageHandler_CatchUpRequest_WithResponse(t *testing.T) { err = gs.blockState.(*state.BlockState).SetHeader(testHeader) require.NoError(t, err) - pvj := []*SignedPrecommit{ + pvj := []*SignedVote{ { Vote: testVote, Signature: testSignature, @@ -358,7 +350,7 @@ func TestMessageHandler_CatchUpRequest_WithResponse(t *testing.T) { pvjEnc, err := scale.Encode(pvj) require.NoError(t, err) - pcj := []*SignedPrecommit{ + pcj := []*SignedVote{ { Vote: testVote2, Signature: testSignature, @@ -392,7 +384,7 @@ func TestVerifyJustification(t *testing.T) { h := NewMessageHandler(gs, st.Block) vote := NewVote(testHash, 123) - just := &SignedPrecommit{ + just := &SignedVote{ Vote: vote, Signature: createSignedVoteMsg(t, vote.number, 77, gs.state.setID, kr.Alice().(*ed25519.Keypair), precommit), AuthorityID: kr.Alice().Public().(*ed25519.PublicKey).AsBytes(), @@ -407,7 +399,7 @@ func TestVerifyJustification_InvalidSignature(t *testing.T) { h := NewMessageHandler(gs, st.Block) vote := NewVote(testHash, 123) - just := &SignedPrecommit{ + just := &SignedVote{ Vote: vote, // create signed vote with mismatched vote number Signature: createSignedVoteMsg(t, vote.number+1, 77, gs.state.setID, kr.Alice().(*ed25519.Keypair), precommit), @@ -426,7 +418,7 @@ func TestVerifyJustification_InvalidAuthority(t *testing.T) { require.NoError(t, err) vote := NewVote(testHash, 123) - just := &SignedPrecommit{ + just := &SignedVote{ Vote: vote, Signature: createSignedVoteMsg(t, vote.number, 77, gs.state.setID, fakeKey, precommit), AuthorityID: fakeKey.Public().(*ed25519.PublicKey).AsBytes(), diff --git a/lib/grandpa/message_test.go b/lib/grandpa/message_test.go index c94f14579f..004b19da6e 100644 --- a/lib/grandpa/message_test.go +++ b/lib/grandpa/message_test.go @@ -35,7 +35,7 @@ func TestVoteMessageToConsensusMessage(t *testing.T) { v.number = 0x7777 // test precommit - vm, err := gs.createVoteMessage(v, precommit, gs.keypair) + _, vm, err := gs.createSignedVoteAndVoteMessage(v, precommit) require.NoError(t, err) vm.Message.Signature = [64]byte{} @@ -53,7 +53,7 @@ func TestVoteMessageToConsensusMessage(t *testing.T) { require.Equal(t, expected, vm) // test prevote - vm, err = gs.createVoteMessage(v, prevote, gs.keypair) + _, vm, err = gs.createSignedVoteAndVoteMessage(v, prevote) require.NoError(t, err) vm.Message.Signature = [64]byte{} @@ -73,7 +73,7 @@ func TestVoteMessageToConsensusMessage(t *testing.T) { func TestCommitMessageToConsensusMessage(t *testing.T) { gs, _ := newTestService(t) - gs.justification[77] = []*SignedPrecommit{ + gs.justification[77] = []*SignedVote{ { Vote: testVote, Signature: testSignature, @@ -113,7 +113,7 @@ func TestNewCatchUpResponse(t *testing.T) { err = gs.blockState.(*state.BlockState).SetHeader(testHeader) require.NoError(t, err) - pvj := []*SignedPrecommit{ + pvj := []*SignedVote{ { Vote: testVote, Signature: testSignature, @@ -124,7 +124,7 @@ func TestNewCatchUpResponse(t *testing.T) { pvjEnc, err := scale.Encode(pvj) require.NoError(t, err) - pcj := []*SignedPrecommit{ + pcj := []*SignedVote{ { Vote: testVote2, Signature: testSignature, diff --git a/lib/grandpa/message_tracker_test.go b/lib/grandpa/message_tracker_test.go index 62f0ff55d9..fb1f5c0278 100644 --- a/lib/grandpa/message_tracker_test.go +++ b/lib/grandpa/message_tracker_test.go @@ -43,8 +43,10 @@ func TestMessageTracker_ValidateMessage(t *testing.T) { Number: big.NewInt(77), } - msg, err := gs.createVoteMessage(NewVoteFromHeader(fake), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(fake), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) _, err = gs.validateMessage(msg) require.Equal(t, err, ErrBlockDoesNotExist) @@ -69,8 +71,10 @@ func TestMessageTracker_SendMessage(t *testing.T) { Number: big.NewInt(4), } - msg, err := gs.createVoteMessage(NewVoteFromHeader(next), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(next), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) _, err = gs.validateMessage(msg) require.Equal(t, err, ErrBlockDoesNotExist) @@ -107,8 +111,10 @@ func TestMessageTracker_ProcessMessage(t *testing.T) { Number: big.NewInt(4), } - msg, err := gs.createVoteMessage(NewVoteFromHeader(next), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(next), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) _, err = gs.validateMessage(msg) require.Equal(t, ErrBlockDoesNotExist, err) @@ -125,5 +131,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) { hash: msg.Message.Hash, number: msg.Message.Number, } - require.Equal(t, expected, gs.prevotes[kr.Alice().Public().(*ed25519.PublicKey).AsBytes()], gs.tracker.messages) + pv, has := gs.prevotes.Load(kr.Alice().Public().(*ed25519.PublicKey).AsBytes()) + require.True(t, has) + require.Equal(t, expected, pv.(*SignedVote).Vote, gs.tracker.messages) } diff --git a/lib/grandpa/network.go b/lib/grandpa/network.go index e9f8d441e4..81272467c3 100644 --- a/lib/grandpa/network.go +++ b/lib/grandpa/network.go @@ -176,6 +176,18 @@ func (s *Service) handleNetworkMessage(from peer.ID, msg NotificationsMessage) ( return true, nil } +// sendMessage sends a vote message to be gossiped to the network +func (s *Service) sendMessage(msg GrandpaMessage) error { + cm, err := msg.ToConsensusMessage() + if err != nil { + return err + } + + s.network.SendMessage(cm) + logger.Trace("sent message", "msg", msg) + return nil +} + func (s *Service) sendNeighbourMessage() { t := time.NewTicker(neighbourMessageInterval) defer t.Stop() @@ -220,14 +232,19 @@ func decodeMessage(msg *ConsensusMessage) (m GrandpaMessage, err error) { switch msg.Data[0] { case voteType: - m = &VoteMessage{} - _, err = scale.Decode(msg.Data[1:], m) + r := &bytes.Buffer{} + _, _ = r.Write(msg.Data[1:]) + vm := &VoteMessage{} + err = vm.Decode(r) + m = vm + logger.Trace("got VoteMessage!!!", "msg", m) case commitType: r := &bytes.Buffer{} _, _ = r.Write(msg.Data[1:]) cm := &CommitMessage{} err = cm.Decode(r) m = cm + logger.Trace("got CommitMessage!!!", "msg", m) case neighbourType: mi, err = scale.Decode(msg.Data[1:], &NeighbourMessage{}) if m, ok = mi.(*NeighbourMessage); !ok { diff --git a/lib/grandpa/network_test.go b/lib/grandpa/network_test.go index d85dda093f..0cfeef0238 100644 --- a/lib/grandpa/network_test.go +++ b/lib/grandpa/network_test.go @@ -49,7 +49,7 @@ func TestGrandpaHandshake_Encode(t *testing.T) { func TestHandleNetworkMessage(t *testing.T) { gs, st := newTestService(t) - gs.justification[77] = []*SignedPrecommit{ + gs.justification[77] = []*SignedVote{ { Vote: testVote, Signature: testSignature, diff --git a/lib/grandpa/round_test.go b/lib/grandpa/round_test.go index a4bfe44e91..ec3f416041 100644 --- a/lib/grandpa/round_test.go +++ b/lib/grandpa/round_test.go @@ -37,10 +37,16 @@ import ( var testTimeout = 20 * time.Second +type testJustificationRequest struct { + to peer.ID + num uint32 +} + type testNetwork struct { - t *testing.T - out chan GrandpaMessage - finalised chan GrandpaMessage + t *testing.T + out chan GrandpaMessage + finalised chan GrandpaMessage + justificationRequest *testJustificationRequest } func newTestNetwork(t *testing.T) *testNetwork { @@ -65,7 +71,12 @@ func (n *testNetwork) SendMessage(msg NotificationsMessage) { } } -func (n *testNetwork) SendJustificationRequest(_ peer.ID, _ uint32) {} +func (n *testNetwork) SendJustificationRequest(to peer.ID, num uint32) { + n.justificationRequest = &testJustificationRequest{ + to: to, + num: num, + } +} func (n *testNetwork) RegisterNotificationsProtocol(sub protocol.ID, messageID byte, @@ -79,6 +90,8 @@ func (n *testNetwork) RegisterNotificationsProtocol(sub protocol.ID, return nil } +func (n *testNetwork) SendBlockReqestByHash(_ common.Hash) {} + func onSameChain(blockState BlockState, a, b common.Hash) bool { descendant, err := blockState.IsDescendantOf(a, b) if err != nil { @@ -122,24 +135,31 @@ func TestGrandpa_BaseCase(t *testing.T) { require.NoError(t, err) gss := make([]*Service, len(kr.Keys)) - prevotes := make(map[ed25519.PublicKeyBytes]*Vote) - precommits := make(map[ed25519.PublicKeyBytes]*Vote) + prevotes := new(sync.Map) + precommits := new(sync.Map) for i, gs := range gss { gs, _, _, _ = setupGrandpa(t, kr.Keys[i]) gss[i] = gs state.AddBlocksToState(t, gs.blockState.(*state.BlockState), 15) - prevotes[gs.publicKeyBytes()], err = gs.determinePreVote() + pv, err := gs.determinePreVote() //nolint require.NoError(t, err) + prevotes.Store(gs.publicKeyBytes(), &SignedVote{ + Vote: pv, + }) } for _, gs := range gss { gs.prevotes = prevotes + gs.precommits = precommits } for _, gs := range gss { - precommits[gs.publicKeyBytes()], err = gs.determinePreCommit() + pc, err := gs.determinePreCommit() require.NoError(t, err) + precommits.Store(gs.publicKeyBytes(), &SignedVote{ + Vote: pc, + }) err = gs.finalise() require.NoError(t, err) has, err := gs.blockState.HasJustification(gs.head.Hash()) @@ -164,8 +184,8 @@ func TestGrandpa_DifferentChains(t *testing.T) { require.NoError(t, err) gss := make([]*Service, len(kr.Keys)) - prevotes := make(map[ed25519.PublicKeyBytes]*Vote) - precommits := make(map[ed25519.PublicKeyBytes]*Vote) + prevotes := new(sync.Map) + precommits := new(sync.Map) for i, gs := range gss { gs, _, _, _ = setupGrandpa(t, kr.Keys[i]) @@ -173,23 +193,34 @@ func TestGrandpa_DifferentChains(t *testing.T) { r := rand.Intn(3) state.AddBlocksToState(t, gs.blockState.(*state.BlockState), 4+r) - prevotes[gs.publicKeyBytes()], err = gs.determinePreVote() + pv, err := gs.determinePreVote() //nolint require.NoError(t, err) + prevotes.Store(gs.publicKeyBytes(), &SignedVote{ + Vote: pv, + }) } // only want to add prevotes for a node that has a block that exists on its chain for _, gs := range gss { - for k, pv := range prevotes { + prevotes.Range(func(key, prevote interface{}) bool { + k := key.(ed25519.PublicKeyBytes) + pv := prevote.(*Vote) err = gs.validateVote(pv) if err == nil { - gs.prevotes[k] = pv + gs.prevotes.Store(k, &SignedVote{ + Vote: pv, + }) } - } + return true + }) } for _, gs := range gss { - precommits[gs.publicKeyBytes()], err = gs.determinePreCommit() + pc, err := gs.determinePreCommit() require.NoError(t, err) + precommits.Store(gs.publicKeyBytes(), &SignedVote{ + Vote: pc, + }) err = gs.finalise() require.NoError(t, err) } @@ -449,7 +480,7 @@ func TestPlayGrandpaRound_OneThirdEquivocating(t *testing.T) { vote, err := NewVoteFromHash(leaves[1], gs.blockState) require.NoError(t, err) - vmsg, err := gs.createVoteMessage(vote, prevote, gs.keypair) + _, vmsg, err := gs.createSignedVoteAndVoteMessage(vote, prevote) require.NoError(t, err) for _, in := range ins { diff --git a/lib/grandpa/state.go b/lib/grandpa/state.go index b145d816e2..4663bff1cb 100644 --- a/lib/grandpa/state.go +++ b/lib/grandpa/state.go @@ -68,6 +68,7 @@ type DigestHandler interface { // TODO: remove, use GrandpaState // Network is the interface required by GRANDPA for the network type Network interface { SendMessage(msg network.NotificationsMessage) + SendBlockReqestByHash(hash common.Hash) SendJustificationRequest(to peer.ID, num uint32) RegisterNotificationsProtocol(sub protocol.ID, messageID byte, diff --git a/lib/grandpa/types.go b/lib/grandpa/types.go index efa9a468c1..8475984aeb 100644 --- a/lib/grandpa/types.go +++ b/lib/grandpa/types.go @@ -52,20 +52,26 @@ func (s subround) Decode(r io.Reader) (subround, error) { return 255, nil } - if b == 0 { + switch b { + case 0: return prevote, nil - } else if b == 1 { + case 1: return precommit, nil - } else { + case 2: + return primaryProposal, nil + default: return 255, ErrCannotDecodeSubround } } func (s subround) String() string { - if s == prevote { + switch s { + case prevote: return "prevote" - } else if s == precommit { + case precommit: return "precommit" + case primaryProposal: + return "primaryProposal" } return "unknown" @@ -188,45 +194,53 @@ func (v *Vote) String() string { return fmt.Sprintf("hash=%s number=%d", v.hash, v.number) } -// SignedPrecommit represents a signed precommit message for a finalised block -type SignedPrecommit struct { +// SignedVote represents a signed precommit message for a finalised block +type SignedVote struct { Vote *Vote Signature [64]byte AuthorityID ed25519.PublicKeyBytes } +func (s *SignedVote) String() string { + return fmt.Sprintf("SignedVote hash=%s number=%d authority=%s", + s.Vote.hash, + s.Vote.number, + s.AuthorityID, + ) +} + // Encode returns the SCALE encoded Justification -func (j *SignedPrecommit) Encode() ([]byte, error) { - enc, err := j.Vote.Encode() +func (s *SignedVote) Encode() ([]byte, error) { + enc, err := s.Vote.Encode() if err != nil { return nil, err } - enc = append(enc, j.Signature[:]...) - enc = append(enc, j.AuthorityID[:]...) + enc = append(enc, s.Signature[:]...) + enc = append(enc, s.AuthorityID[:]...) return enc, nil } // Decode returns the SCALE decoded Justification -func (j *SignedPrecommit) Decode(r io.Reader) (*SignedPrecommit, error) { +func (s *SignedVote) Decode(r io.Reader) (*SignedVote, error) { sd := &scale.Decoder{Reader: r} - i, err := sd.Decode(j) + i, err := sd.Decode(s) if err != nil { return nil, err } - d := i.(*SignedPrecommit) - j.Vote = d.Vote - j.Signature = d.Signature - j.AuthorityID = d.AuthorityID - return j, nil + d := i.(*SignedVote) + s.Vote = d.Vote + s.Signature = d.Signature + s.AuthorityID = d.AuthorityID + return s, nil } // Commit contains all the signed precommits for a given block type Commit struct { Hash common.Hash Number uint32 - Precommits []*SignedPrecommit + Precommits []*SignedVote } // Justification represents a finality justification for a block @@ -235,7 +249,7 @@ type Justification struct { Commit *Commit } -func newJustification(round uint64, hash common.Hash, number uint32, j []*SignedPrecommit) *Justification { +func newJustification(round uint64, hash common.Hash, number uint32, j []*SignedVote) *Justification { return &Justification{ Round: round, Commit: &Commit{ diff --git a/lib/grandpa/types_test.go b/lib/grandpa/types_test.go index c88d68b866..e1d882f07a 100644 --- a/lib/grandpa/types_test.go +++ b/lib/grandpa/types_test.go @@ -38,8 +38,8 @@ func TestPubkeyToVoter(t *testing.T) { require.Equal(t, voters[0], voter) } -func TestSignedPrecommitEncoding(t *testing.T) { - just := &SignedPrecommit{ +func TestSignedVoteEncoding(t *testing.T) { + just := &SignedVote{ Vote: testVote, Signature: testSignature, AuthorityID: testAuthorityID, @@ -50,14 +50,14 @@ func TestSignedPrecommitEncoding(t *testing.T) { rw := &bytes.Buffer{} rw.Write(enc) - dec := new(SignedPrecommit) + dec := new(SignedVote) _, err = dec.Decode(rw) require.NoError(t, err) require.Equal(t, just, dec) } -func TestSignedPrecommitArrayEncoding(t *testing.T) { - just := []*SignedPrecommit{ +func TestSignedVoteArrayEncoding(t *testing.T) { + just := []*SignedVote{ { Vote: testVote, Signature: testSignature, @@ -68,13 +68,13 @@ func TestSignedPrecommitArrayEncoding(t *testing.T) { enc, err := scale.Encode(just) require.NoError(t, err) - dec, err := scale.Decode(enc, make([]*SignedPrecommit, 1)) + dec, err := scale.Decode(enc, make([]*SignedVote, 1)) require.NoError(t, err) - require.Equal(t, just, dec.([]*SignedPrecommit)) + require.Equal(t, just, dec.([]*SignedVote)) } func TestJustification(t *testing.T) { - just := &SignedPrecommit{ + just := &SignedVote{ Vote: testVote, Signature: testSignature, AuthorityID: testAuthorityID, @@ -83,7 +83,7 @@ func TestJustification(t *testing.T) { fj := &Justification{ Round: 99, Commit: &Commit{ - Precommits: []*SignedPrecommit{just}, + Precommits: []*SignedVote{just}, }, } enc, err := fj.Encode() diff --git a/lib/grandpa/vote_message.go b/lib/grandpa/vote_message.go index 641eb4d7c6..41578f333d 100644 --- a/lib/grandpa/vote_message.go +++ b/lib/grandpa/vote_message.go @@ -22,7 +22,7 @@ import ( "errors" "time" - "github.com/ChainSafe/gossamer/lib/crypto" + "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/crypto/ed25519" "github.com/ChainSafe/gossamer/lib/scale" ) @@ -47,7 +47,7 @@ func (s *Service) receiveMessages(cond func() bool) { v, err := s.validateMessage(vm) if err != nil { - logger.Trace("failed to validate vote message", "message", vm, "error", err) + logger.Debug("failed to validate vote message", "message", vm, "error", err) continue } @@ -55,8 +55,9 @@ func (s *Service) receiveMessages(cond func() bool) { "vote", v, "round", vm.Round, "subround", vm.Message.Stage, - "prevotes", s.prevotes, - "precommits", s.precommits, + "prevote count", s.lenVotes(prevote), + "precommit count", s.lenVotes(precommit), + "votes needed", s.state.threshold(), ) case <-ctx.Done(): logger.Trace("returning from receiveMessages") @@ -74,34 +75,7 @@ func (s *Service) receiveMessages(cond func() bool) { } } -// sendMessage sends a message through the out channel -func (s *Service) sendMessage(vote *Vote, stage subround) error { - msg, err := s.createVoteMessage(vote, stage, s.keypair) - if err != nil { - return err - } - - cm, err := msg.ToConsensusMessage() - if err != nil { - return err - } - - s.chanLock.Lock() - defer s.chanLock.Unlock() - - // context was canceled - if s.ctx.Err() != nil { - return nil - } - - s.network.SendMessage(cm) - logger.Trace("sent VoteMessage", "msg", msg) - - return nil -} - -// createVoteMessage returns a signed VoteMessage given a header -func (s *Service) createVoteMessage(vote *Vote, stage subround, kp crypto.Keypair) (*VoteMessage, error) { +func (s *Service) createSignedVoteAndVoteMessage(vote *Vote, stage subround) (*SignedVote, *VoteMessage, error) { msg, err := scale.Encode(&FullVote{ Stage: stage, Vote: vote, @@ -109,27 +83,35 @@ func (s *Service) createVoteMessage(vote *Vote, stage subround, kp crypto.Keypai SetID: s.state.setID, }) if err != nil { - return nil, err + return nil, nil, err } - sig, err := kp.Sign(msg) + sig, err := s.keypair.Sign(msg) if err != nil { - return nil, err + return nil, nil, err + } + + pc := &SignedVote{ + Vote: vote, + Signature: ed25519.NewSignatureBytes(sig), + AuthorityID: s.keypair.Public().(*ed25519.PublicKey).AsBytes(), } sm := &SignedMessage{ Stage: stage, - Hash: vote.hash, - Number: vote.number, + Hash: pc.Vote.hash, + Number: pc.Vote.number, Signature: ed25519.NewSignatureBytes(sig), - AuthorityID: kp.Public().(*ed25519.PublicKey).AsBytes(), + AuthorityID: s.keypair.Public().(*ed25519.PublicKey).AsBytes(), } - return &VoteMessage{ + vm := &VoteMessage{ Round: s.state.round, SetID: s.state.setID, Message: sm, - }, nil + } + + return pc, vm, nil } // validateMessage validates a VoteMessage and adds it to the current votes @@ -149,6 +131,19 @@ func (s *Service) validateMessage(m *VoteMessage) (*Vote, error) { return nil, err } + switch m.Message.Stage { + case prevote, primaryProposal: + pv, has := s.loadVote(pk.AsBytes(), prevote) + if has && pv.Vote.hash.Equal(m.Message.Hash) { + return nil, errVoteExists + } + case precommit: + pc, has := s.loadVote(pk.AsBytes(), precommit) + if has && pc.Vote.hash.Equal(m.Message.Hash) { + return nil, errVoteExists + } + } + err = validateMessageSignature(pk, m) if err != nil { return nil, err @@ -161,6 +156,24 @@ func (s *Service) validateMessage(m *VoteMessage) (*Vote, error) { // check that vote is for current round if m.Round != s.state.round { + if m.Round < s.state.round { + // peer doesn't know round was finalised, send out another commit message + header, err := s.blockState.GetFinalizedHeader(m.Round, m.SetID) //nolint + if err != nil { + return nil, err + } + + // send finalised block from previous round to network + msg, err := s.newCommitMessage(header, m.Round).ToConsensusMessage() + if err != nil { + return nil, err + } + + // TODO: don't broadcast, just send to peer; will address in a follow-up + s.network.SendMessage(msg) + } + + // TODO: get justification if your round is lower, or just do catch-up? return nil, ErrRoundMismatch } @@ -179,38 +192,32 @@ func (s *Service) validateMessage(m *VoteMessage) (*Vote, error) { } err = s.validateVote(vote) - if err == ErrBlockDoesNotExist { + if errors.Is(err, ErrBlockDoesNotExist) || errors.Is(err, blocktree.ErrEndNodeNotFound) { + // TODO: cancel if block is imported; if we refactor the syncing this will likely become cleaner + // as we can have an API to synchronously sync and import a block + go s.network.SendBlockReqestByHash(vote.hash) s.tracker.add(m) } if err != nil { return nil, err } - s.mapLock.Lock() - defer s.mapLock.Unlock() - - just := &SignedPrecommit{ + just := &SignedVote{ Vote: vote, Signature: m.Message.Signature, AuthorityID: pk.AsBytes(), } - // add justification before checking for equivocation, since equivocatory vote may still be used in justification - if m.Message.Stage == prevote { - s.pvJustifications[m.Message.Hash] = append(s.pvJustifications[m.Message.Hash], just) - } else if m.Message.Stage == precommit { - s.pcJustifications[m.Message.Hash] = append(s.pcJustifications[m.Message.Hash], just) - } - - equivocated := s.checkForEquivocation(voter, vote, m.Message.Stage) + equivocated := s.checkForEquivocation(voter, just, m.Message.Stage) if equivocated { return nil, ErrEquivocation } - if m.Message.Stage == prevote { - s.prevotes[pk.AsBytes()] = vote - } else if m.Message.Stage == precommit { - s.precommits[pk.AsBytes()] = vote + switch m.Message.Stage { + case prevote, primaryProposal: + s.prevotes.Store(pk.AsBytes(), just) + case precommit: + s.precommits.Store(pk.AsBytes(), just) } return vote, nil @@ -219,31 +226,38 @@ func (s *Service) validateMessage(m *VoteMessage) (*Vote, error) { // checkForEquivocation checks if the vote is an equivocatory vote. // it returns true if so, false otherwise. // additionally, if the vote is equivocatory, it updates the service's votes and equivocations. -func (s *Service) checkForEquivocation(voter *Voter, vote *Vote, stage subround) bool { +func (s *Service) checkForEquivocation(voter *Voter, vote *SignedVote, stage subround) bool { v := voter.Key.AsBytes() - var eq map[ed25519.PublicKeyBytes][]*Vote - var votes map[ed25519.PublicKeyBytes]*Vote + // save justification, since equivocatory vote may still be used in justification + var eq map[ed25519.PublicKeyBytes][]*SignedVote - if stage == prevote { + switch stage { + case prevote, primaryProposal: eq = s.pvEquivocations - votes = s.prevotes - } else { + case precommit: eq = s.pcEquivocations - votes = s.precommits } - if eq[v] != nil { + s.mapLock.Lock() + defer s.mapLock.Unlock() + + _, has := eq[v] + if has { // if the voter has already equivocated, every vote in that round is an equivocatory vote eq[v] = append(eq[v], vote) return true } - if votes[v] != nil && votes[v].hash != vote.hash { + existingVote, has := s.loadVote(v, stage) + if !has { + return false + } + + if has && existingVote.Vote.hash != vote.Vote.hash { // the voter has already voted, all their votes are now equivocatory - prev := votes[v] - eq[v] = []*Vote{prev, vote} - delete(votes, v) + eq[v] = []*SignedVote{existingVote, vote} + s.deleteVote(v, stage) return true } diff --git a/lib/grandpa/vote_message_test.go b/lib/grandpa/vote_message_test.go index b24b728923..530719c611 100644 --- a/lib/grandpa/vote_message_test.go +++ b/lib/grandpa/vote_message_test.go @@ -55,7 +55,9 @@ func TestCheckForEquivocation_NoEquivocation(t *testing.T) { require.NoError(t, err) for _, v := range voters { - equivocated := gs.checkForEquivocation(v, vote, prevote) + equivocated := gs.checkForEquivocation(v, &SignedVote{ + Vote: vote, + }, prevote) require.False(t, equivocated) } } @@ -89,15 +91,19 @@ func TestCheckForEquivocation_WithEquivocation(t *testing.T) { voter := voters[0] - gs.prevotes[voter.Key.AsBytes()] = vote1 + gs.prevotes.Store(voter.Key.AsBytes(), &SignedVote{ + Vote: vote1, + }) vote2, err := NewVoteFromHash(leaves[1], st.Block) require.NoError(t, err) - equivocated := gs.checkForEquivocation(voter, vote2, prevote) + equivocated := gs.checkForEquivocation(voter, &SignedVote{ + Vote: vote2, + }, prevote) require.True(t, equivocated) - require.Equal(t, 0, len(gs.prevotes)) + require.Equal(t, 0, gs.lenVotes(prevote)) require.Equal(t, 1, len(gs.pvEquivocations)) require.Equal(t, 2, len(gs.pvEquivocations[voter.Key.AsBytes()])) } @@ -137,24 +143,30 @@ func TestCheckForEquivocation_WithExistingEquivocation(t *testing.T) { voter := voters[0] - gs.prevotes[voter.Key.AsBytes()] = vote + gs.prevotes.Store(voter.Key.AsBytes(), &SignedVote{ + Vote: vote, + }) vote2 := NewVoteFromHeader(branches[0]) require.NoError(t, err) - equivocated := gs.checkForEquivocation(voter, vote2, prevote) + equivocated := gs.checkForEquivocation(voter, &SignedVote{ + Vote: vote2, + }, prevote) require.True(t, equivocated) - require.Equal(t, 0, len(gs.prevotes)) + require.Equal(t, 0, gs.lenVotes(prevote)) require.Equal(t, 1, len(gs.pvEquivocations)) vote3 := NewVoteFromHeader(branches[1]) require.NoError(t, err) - equivocated = gs.checkForEquivocation(voter, vote3, prevote) + equivocated = gs.checkForEquivocation(voter, &SignedVote{ + Vote: vote3, + }, prevote) require.True(t, equivocated) - require.Equal(t, 0, len(gs.prevotes)) + require.Equal(t, 0, gs.lenVotes(prevote)) require.Equal(t, 1, len(gs.pvEquivocations)) require.Equal(t, 3, len(gs.pvEquivocations[voter.Key.AsBytes()])) } @@ -182,8 +194,10 @@ func TestValidateMessage_Valid(t *testing.T) { h, err := st.Block.BestBlockHeader() require.NoError(t, err) - msg, err := gs.createVoteMessage(NewVoteFromHeader(h), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(h), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) vote, err := gs.validateMessage(msg) require.NoError(t, err) @@ -213,8 +227,10 @@ func TestValidateMessage_InvalidSignature(t *testing.T) { h, err := st.Block.BestBlockHeader() require.NoError(t, err) - msg, err := gs.createVoteMessage(NewVoteFromHeader(h), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(h), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) msg.Message.Signature[63] = 0 @@ -244,8 +260,10 @@ func TestValidateMessage_SetIDMismatch(t *testing.T) { h, err := st.Block.BestBlockHeader() require.NoError(t, err) - msg, err := gs.createVoteMessage(NewVoteFromHeader(h), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(h), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) gs.state.setID = 1 @@ -280,18 +298,22 @@ func TestValidateMessage_Equivocation(t *testing.T) { } } - h, err := st.Block.BestBlockHeader() + leaves := gs.blockState.Leaves() + voteA, err := NewVoteFromHash(leaves[0], st.Block) require.NoError(t, err) - - vote := NewVoteFromHeader(h) + voteB, err := NewVoteFromHash(leaves[1], st.Block) require.NoError(t, err) voter := voters[0] - gs.prevotes[voter.Key.AsBytes()] = vote + gs.prevotes.Store(voter.Key.AsBytes(), &SignedVote{ + Vote: voteA, + }) - msg, err := gs.createVoteMessage(NewVoteFromHeader(branches[0]), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(voteB, prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) _, err = gs.validateMessage(msg) require.Equal(t, ErrEquivocation, err, gs.prevotes) @@ -323,8 +345,10 @@ func TestValidateMessage_BlockDoesNotExist(t *testing.T) { Number: big.NewInt(77), } - msg, err := gs.createVoteMessage(NewVoteFromHeader(fake), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(fake), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) _, err = gs.validateMessage(msg) require.Equal(t, err, ErrBlockDoesNotExist) @@ -361,8 +385,10 @@ func TestValidateMessage_IsNotDescendant(t *testing.T) { require.NoError(t, err) gs.head = h - msg, err := gs.createVoteMessage(NewVoteFromHeader(branches[0]), prevote, kr.Alice()) + gs.keypair = kr.Alice().(*ed25519.Keypair) + _, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(branches[0]), prevote) require.NoError(t, err) + gs.keypair = kr.Bob().(*ed25519.Keypair) _, err = gs.validateMessage(msg) require.Equal(t, ErrDescendantNotFound, err, gs.prevotes)