Skip to content

Commit

Permalink
fix(lib/grandpa): fix threshold checking to be strictly greater than …
Browse files Browse the repository at this point in the history
  • Loading branch information
noot authored and timwu20 committed Dec 6, 2021
1 parent 165903d commit ed51f97
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 182 deletions.
5 changes: 5 additions & 0 deletions dot/state/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,16 @@ func AddBlocksToStateWithFixedBranches(t *testing.T, blockState *BlockState, dep
// create base tree
startNum := int(head.Number.Int64())
for i := startNum + 1; i <= depth; i++ {
d := types.NewBabePrimaryPreDigest(0, uint64(i), [32]byte{}, [64]byte{})
digest := types.NewDigest()
_ = digest.Add(*d.ToPreRuntimeDigest())

block := &types.Block{
Header: types.Header{
ParentHash: previousHash,
Number: big.NewInt(int64(i)),
StateRoot: trie.EmptyHash,
Digest: digest,
},
Body: types.Body{},
}
Expand Down
98 changes: 42 additions & 56 deletions lib/grandpa/grandpa.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func NewService(cfg *Config) (*Service, error) {
}

s.messageHandler = NewMessageHandler(s, s.blockState)
s.tracker = newTracker(s.blockState, s.messageHandler)
s.paused.Store(false)
return s, nil
}
Expand All @@ -187,6 +188,8 @@ func (s *Service) Start() error {
return nil
}

s.tracker.start()

go func() {
if err := s.initiate(); err != nil {
logger.Crit("failed to initiate", "error", err)
Expand All @@ -203,7 +206,6 @@ func (s *Service) Stop() error {
defer s.chanLock.Unlock()

s.cancel()

s.blockState.FreeFinalisedNotifierChannel(s.finalisedCh)

if !s.authority {
Expand Down Expand Up @@ -256,13 +258,14 @@ func (s *Service) updateAuthorities() error {

s.state.voters = nextAuthorities
s.state.setID = currSetID
s.state.round = 1 // round resets to 1 after a set ID change
s.state.round = 0 // round resets to 1 after a set ID change, setting to 0 before incrementing indicates the setID has been increased
return nil
}

func (s *Service) publicKeyBytes() ed25519.PublicKeyBytes {
return s.keypair.Public().(*ed25519.PublicKey).AsBytes()
}

func (s *Service) initiateRound() error {
// if there is an authority change, execute it
err := s.updateAuthorities()
Expand Down Expand Up @@ -297,7 +300,7 @@ func (s *Service) initiateRound() error {
}

// there was a setID change, or the node was started from genesis
if s.state.round == 1 {
if s.state.round == 0 {
s.chanLock.Lock()
s.mapLock.Lock()
s.preVotedBlock[0] = NewVoteFromHeader(s.head)
Expand All @@ -310,17 +313,10 @@ func (s *Service) initiateRound() error {
s.roundLock.Lock()
s.state.round++
logger.Debug("incrementing grandpa round", "next round", s.state.round)
if s.tracker != nil {
s.tracker.stop()
}

s.prevotes = new(sync.Map)
s.precommits = new(sync.Map)
s.pvEquivocations = make(map[ed25519.PublicKeyBytes][]*SignedVote)
s.pcEquivocations = make(map[ed25519.PublicKeyBytes][]*SignedVote)
s.tracker = newTracker(s.blockState, s.messageHandler)
s.tracker.start()
logger.Trace("started message tracker")
s.roundLock.Unlock()

best, err := s.blockState.BestBlockHeader()
Expand All @@ -333,7 +329,8 @@ func (s *Service) initiateRound() error {
}

// don't begin grandpa until we are at block 1
return s.waitForFirstBlock()
s.waitForFirstBlock()
return nil
}

// initiate initates the grandpa service to begin voting in sequential rounds
Expand Down Expand Up @@ -364,30 +361,21 @@ func (s *Service) initiate() error {
}
}

func (s *Service) waitForFirstBlock() error {
func (s *Service) waitForFirstBlock() {
ch := s.blockState.GetImportedBlockNotifierChannel()

defer s.blockState.FreeImportedBlockNotifierChannel(ch)

// loop until block 1
for {
done := false

select {
case block := <-ch:
if block != nil && block.Header.Number.Int64() > 0 {
done = true
return
}
case <-s.ctx.Done():
return nil
}

if done {
break
return
}
}

return nil
}

func (s *Service) handleIsPrimary() (bool, error) {
Expand Down Expand Up @@ -452,6 +440,8 @@ func (s *Service) primaryBroadcastCommitMessage() {
// at the end of this round, a block will be finalised.
func (s *Service) playGrandpaRound() error {
logger.Debug("starting round", "round", s.state.round, "setID", s.state.setID)
start := time.Now()

ctx, cancel := context.WithCancel(s.ctx)
defer cancel()

Expand Down Expand Up @@ -514,12 +504,12 @@ func (s *Service) playGrandpaRound() error {
// continue to send precommit messages until round is done
go s.sendVoteMessage(precommit, pcm, roundComplete)

err = s.attemptToFinalize()
if err != nil {
if err = s.attemptToFinalize(); err != nil {
logger.Error("failed to finalise", "error", err)
return err
}

logger.Debug("round completed", "duration", time.Since(start))
return nil
}

Expand All @@ -532,8 +522,7 @@ func (s *Service) sendVoteMessage(stage Subround, msg *VoteMessage, roundComplet
return
}

err := s.sendMessage(msg)
if err != nil {
if err := s.sendMessage(msg); err != nil {
logger.Warn("could not send message", "stage", stage, "error", err)
}

Expand Down Expand Up @@ -588,12 +577,11 @@ func (s *Service) attemptToFinalize() error {
return err
}

if bfc.Number < uint32(s.head.Number.Int64()) || pc < s.state.threshold() {
if bfc.Number < uint32(s.head.Number.Int64()) || pc <= s.state.threshold() {
continue
}

err = s.finalise()
if err != nil {
if err = s.finalise(); err != nil {
return err
}

Expand Down Expand Up @@ -749,7 +737,7 @@ func (s *Service) isFinalisable(round uint64) (bool, error) {
return false, errors.New("cannot find best final candidate for previous round")
}

if bfc.Number <= pvb.Number && (s.state.round == 0 || prevBfc.Number <= bfc.Number) && pc >= s.state.threshold() {
if bfc.Number <= pvb.Number && (s.state.round == 0 || prevBfc.Number <= bfc.Number) && pc > s.state.threshold() {
return true, nil
}

Expand Down Expand Up @@ -867,20 +855,20 @@ func (s *Service) derivePrimary() Voter {
}

// getBestFinalCandidate calculates the set of blocks that are less than or equal to the pre-voted block in height,
// with >= 2/3 pre-commit votes, then returns the block with the highest number from this set.
// with >2/3 pre-commit votes, then returns the block with the highest number from this set.
func (s *Service) getBestFinalCandidate() (*Vote, error) {
prevoted, err := s.getPreVotedBlock()
if err != nil {
return nil, err
}

// get all blocks with >=2/3 pre-commits
// get all blocks with >2/3 pre-commits
blocks, err := s.getPossibleSelectedBlocks(precommit, s.state.threshold())
if err != nil {
return nil, err
}

// if there are no blocks with >=2/3 pre-commits, just return the pre-voted block
// if there are no blocks with >2/3 pre-commits, just return the pre-voted block
// TODO: is this correct? the spec implies that it should return nil, but discussions have suggested
// that we return the prevoted block. (#1815)
if len(blocks) == 0 {
Expand All @@ -890,6 +878,7 @@ func (s *Service) getBestFinalCandidate() (*Vote, error) {
// if there are multiple blocks, get the one with the highest number
// that is also an ancestor of the prevoted block (or is the prevoted block)
bfc := &Vote{
Hash: s.blockState.GenesisHash(),
Number: 0,
}

Expand All @@ -901,7 +890,7 @@ func (s *Service) getBestFinalCandidate() (*Vote, error) {
}

if !isDescendant {
// find common ancestor, implicitly has >=2/3 votes
// find common ancestor, implicitly has >2/3 votes
pred, err := s.blockState.HighestCommonAncestor(h, prevoted.Hash)
if err != nil {
return nil, err
Expand All @@ -925,17 +914,13 @@ func (s *Service) getBestFinalCandidate() (*Vote, error) {
}
}

if [32]byte(bfc.Hash) == [32]byte{} {
return &prevoted, nil
}

return bfc, nil
}

// isCompletable returns true if the round is completable, false otherwise
func (s *Service) isCompletable() (bool, error) {
// haven't received enough votes, not completable
if uint64(s.lenVotes(precommit)+len(s.pcEquivocations)) < s.state.threshold() {
if uint64(s.lenVotes(precommit)+len(s.pcEquivocations)) <= s.state.threshold() {
return false, nil
}

Expand All @@ -947,7 +932,7 @@ func (s *Service) isCompletable() (bool, error) {
votes := s.getVotes(precommit)

// for each block with at least 1 vote that's a descendant of the prevoted block,
// check that (total precommits - total pc equivocations - precommits for that block) >= 2/3 |V|
// check that (total precommits - total pc equivocations - precommits for that block) >2/3|V|
// ie. there must not be a descendent of the prevotes block that is preferred
for _, v := range votes {
if prevoted.Hash == v.Hash {
Expand All @@ -969,7 +954,7 @@ func (s *Service) isCompletable() (bool, error) {
return false, err
}

if uint64(len(votes)-len(s.pcEquivocations))-c < s.state.threshold() {
if uint64(len(votes)-len(s.pcEquivocations))-c <= s.state.threshold() {
// round isn't completable
return false, nil
}
Expand All @@ -980,7 +965,7 @@ func (s *Service) isCompletable() (bool, error) {

// getPreVotedBlock returns the current pre-voted block B. also known as GRANDPA-GHOST.
// the pre-voted block is the block with the highest block number in the set of all the blocks with
// total votes >= 2/3 the total number of voters, where the total votes is determined by getTotalVotesForBlock.
// total votes >2/3 the total number of voters, where the total votes is determined by getTotalVotesForBlock.
func (s *Service) getPreVotedBlock() (Vote, error) {
blocks, err := s.getPossibleSelectedBlocks(prevote, s.state.threshold())
if err != nil {
Expand Down Expand Up @@ -1035,10 +1020,11 @@ func (s *Service) getGrandpaGHOST() (Vote, error) {
return Vote{}, err
}

threshold--
if len(blocks) > 0 || threshold == 0 {
break
}

threshold--
}

if len(blocks) == 0 {
Expand All @@ -1062,36 +1048,36 @@ func (s *Service) getGrandpaGHOST() (Vote, error) {
return highest, nil
}

// getPossibleSelectedBlocks returns blocks with total votes >=threshold in a map of block hash -> block number.
// if there are no blocks that have >=threshold direct votes, this function will find ancestors of those blocks that do have >=threshold votes.
// getPossibleSelectedBlocks returns blocks with total votes >threshold in a map of block hash -> block number.
// if there are no blocks that have >threshold direct votes, this function will find ancestors of those blocks that do have >threshold votes.
// note that by voting for a block, all of its ancestor blocks are automatically voted for.
// thus, if there are no blocks with >=threshold total votes, but the sum of votes for blocks A and B is >=threshold, then this function returns
// thus, if there are no blocks with >threshold total votes, but the sum of votes for blocks A and B is >threshold, then this function returns
// the first common ancestor of A and B.
// in general, this function will return the highest block on each chain with >=threshold votes.
// in general, this function will return the highest block on each chain with >threshold votes.
func (s *Service) getPossibleSelectedBlocks(stage Subround, threshold uint64) (map[common.Hash]uint32, error) {
// get blocks that were directly voted for
votes := s.getDirectVotes(stage)
blocks := make(map[common.Hash]uint32)

// check if any of them have >=threshold votes
// check if any of them have >threshold votes
for v := range votes {
total, err := s.getTotalVotesForBlock(v.Hash, stage)
if err != nil {
return nil, err
}

if total >= threshold {
if total > threshold {
blocks[v.Hash] = v.Number
}
}

// since we want to select the block with the highest number that has >=threshold votes,
// since we want to select the block with the highest number that has >threshold votes,
// we can return here since their ancestors won't have a higher number.
if len(blocks) != 0 {
return blocks, nil
}

// no block has >=threshold direct votes, check for votes for ancestors recursively
// no block has >threshold direct votes, check for votes for ancestors recursively
var err error
va := s.getVotes(stage)

Expand All @@ -1105,15 +1091,15 @@ func (s *Service) getPossibleSelectedBlocks(stage Subround, threshold uint64) (m
return blocks, nil
}

// getPossibleSelectedAncestors recursively searches for ancestors with >=2/3 votes
// it returns a map of block hash -> number, such that the blocks in the map have >=2/3 votes
// getPossibleSelectedAncestors recursively searches for ancestors with >2/3 votes
// it returns a map of block hash -> number, such that the blocks in the map have >2/3 votes
func (s *Service) getPossibleSelectedAncestors(votes []Vote, curr common.Hash, selected map[common.Hash]uint32, stage Subround, threshold uint64) (map[common.Hash]uint32, error) {
for _, v := range votes {
if v.Hash == curr {
continue
}

// find common ancestor, check if votes for it is >=threshold or not
// find common ancestor, check if votes for it is >threshold or not
pred, err := s.blockState.HighestCommonAncestor(v.Hash, curr)
if err == blocktree.ErrNodeNotFound {
continue
Expand All @@ -1130,7 +1116,7 @@ func (s *Service) getPossibleSelectedAncestors(votes []Vote, curr common.Hash, s
return nil, err
}

if total >= threshold {
if total > threshold {
var h *types.Header
h, err = s.blockState.GetHeader(pred)
if err != nil {
Expand Down
Loading

0 comments on commit ed51f97

Please sign in to comment.