Skip to content

Commit

Permalink
Consolidate istanbul.Message encoding/decoding (#1592)
Browse files Browse the repository at this point in the history
Message decoding and encoding, prevously happend in 2 steps. This commit
makes decoding and encoding a message an atomic operation, either it
succeeds or it fails.

This allows us to decode in just one place in the code, allowing us to
have a consistent approach to decode errors. By being able to respond to
decode errors at a higer level in the stack, we remove the need to pass
errors back up the stack to indicate invalid messages. It also helps
separate encoding and decoding logic from our domain logic.
  • Loading branch information
piersy authored Jul 12, 2021
1 parent 76f5918 commit 968e504
Show file tree
Hide file tree
Showing 26 changed files with 841 additions and 1,170 deletions.
223 changes: 49 additions & 174 deletions consensus/istanbul/backend/announce.go

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions consensus/istanbul/backend/announce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ func TestAnnounceGossipQueryMsg(t *testing.T) {

engine0Enode := engine0.SelfNode()

w1 := engine1.wallets()
// Create version certificate messages for engine1 and engine2, so that engine0 will send a queryEnodeMessage to them
vCert1, err := generateVersionCertificate(w1.Ecdsa.Address, w1.Ecdsa.PublicKey, engine1AnnounceVersion, w1.Ecdsa.Sign)
vCert1, err := istanbul.NewVersionCertificate(engine1AnnounceVersion, engine1.Sign)
if err != nil {
t.Errorf("Error in generating version certificate for engine1. Error: %v", err)
}
w2 := engine2.wallets()
vCert2, err := generateVersionCertificate(w2.Ecdsa.Address, w2.Ecdsa.PublicKey, engine2AnnounceVersion, w2.Ecdsa.Sign)

vCert2, err := istanbul.NewVersionCertificate(engine1AnnounceVersion, engine2.Sign)
if err != nil {
t.Errorf("Error in generating version certificate for engine2. Error: %v", err)
}

// Have engine0 handle vCert messages from engine1 and engine2
vCert1MsgPayload, err := encodeVersionCertificatesMsg([]*versionCertificate{vCert1})

vCert1MsgPayload, err := istanbul.NewVersionCeritifcatesMessage([]*istanbul.VersionCertificate{vCert1}, engine1Address).Payload()
if err != nil {
t.Errorf("Error in encoding vCert1. Error: %v", err)
}
Expand All @@ -58,7 +58,7 @@ func TestAnnounceGossipQueryMsg(t *testing.T) {
t.Errorf("Error in handling vCert1. Error: %v", err)
}

vCert2MsgPayload, err := encodeVersionCertificatesMsg([]*versionCertificate{vCert2})
vCert2MsgPayload, err := istanbul.NewVersionCeritifcatesMessage([]*istanbul.VersionCertificate{vCert2}, engine2Address).Payload()
if err != nil {
t.Errorf("Error in encoding vCert2. Error: %v", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@
package enodes

import (
"crypto/ecdsa"
"encoding/hex"
"fmt"
"io"
"strings"

"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/opt"

"github.com/celo-org/celo-blockchain/common"
"github.com/celo-org/celo-blockchain/consensus/istanbul"
"github.com/celo-org/celo-blockchain/consensus/istanbul/backend/internal/db"
"github.com/celo-org/celo-blockchain/crypto"
"github.com/celo-org/celo-blockchain/log"
"github.com/celo-org/celo-blockchain/rlp"
)
Expand All @@ -43,55 +40,14 @@ type VersionCertificateDB struct {
logger log.Logger
}

// VersionCertificateEntry is an entry in the VersionCertificateDB.
// It's a signed message from a registered or active validator indicating
// the most recent version of its enode.
type VersionCertificateEntry struct {
Address common.Address
PublicKey *ecdsa.PublicKey
Version uint
Signature []byte
}

func versionCertificateEntryFromGenericEntry(entry db.GenericEntry) (*VersionCertificateEntry, error) {
signedAnnVersionEntry, ok := entry.(*VersionCertificateEntry)
func versionCertificateEntryFromGenericEntry(entry db.GenericEntry) (*istanbul.VersionCertificate, error) {
signedAnnVersionEntry, ok := entry.(*istanbul.VersionCertificate)
if !ok {
return nil, errIncorrectEntryType
}
return signedAnnVersionEntry, nil
}

// EncodeRLP serializes VersionCertificateEntry into the Ethereum RLP format.
func (entry *VersionCertificateEntry) EncodeRLP(w io.Writer) error {
encodedPublicKey := crypto.FromECDSAPub(entry.PublicKey)
return rlp.Encode(w, []interface{}{entry.Address, encodedPublicKey, entry.Version, entry.Signature})
}

// DecodeRLP implements rlp.Decoder, and load the VersionCertificateEntry fields from a RLP stream.
func (entry *VersionCertificateEntry) DecodeRLP(s *rlp.Stream) error {
var content struct {
Address common.Address
PublicKey []byte
Version uint
Signature []byte
}

if err := s.Decode(&content); err != nil {
return err
}
decodedPublicKey, err := crypto.UnmarshalPubkey(content.PublicKey)
if err != nil {
return err
}
entry.Address, entry.PublicKey, entry.Version, entry.Signature = content.Address, decodedPublicKey, content.Version, content.Signature
return nil
}

// String gives a string representation of VersionCertificateEntry
func (entry *VersionCertificateEntry) String() string {
return fmt.Sprintf("{Address: %v, Version: %v, Signature: %v}", entry.Address, entry.Version, hex.EncodeToString(entry.Signature))
}

// OpenVersionCertificateDB opens a signed announce version database for storing
// VersionCertificates. If no path is given an in-memory, temporary database is constructed.
func OpenVersionCertificateDB(path string) (*VersionCertificateDB, error) {
Expand Down Expand Up @@ -119,7 +75,7 @@ func (svdb *VersionCertificateDB) String() string {
var b strings.Builder
b.WriteString("VersionCertificateDB:")

err := svdb.iterate(func(address common.Address, entry *VersionCertificateEntry) error {
err := svdb.iterate(func(address common.Address, entry *istanbul.VersionCertificate) error {
fmt.Fprintf(&b, " [%s => %s]", address.String(), entry.String())
return nil
})
Expand All @@ -133,17 +89,17 @@ func (svdb *VersionCertificateDB) String() string {

// Upsert inserts any new entries or entries with a Version higher than the
// existing version. Returns any new or updated entries
func (svdb *VersionCertificateDB) Upsert(savEntries []*VersionCertificateEntry) ([]*VersionCertificateEntry, error) {
func (svdb *VersionCertificateDB) Upsert(savEntries []*istanbul.VersionCertificate) ([]*istanbul.VersionCertificate, error) {
logger := svdb.logger.New("func", "Upsert")

var newEntries []*VersionCertificateEntry
var newEntries []*istanbul.VersionCertificate

getExistingEntry := func(entry db.GenericEntry) (db.GenericEntry, error) {
savEntry, err := versionCertificateEntryFromGenericEntry(entry)
if err != nil {
return entry, err
}
return svdb.Get(savEntry.Address)
return svdb.Get(savEntry.Address())
}

onNewEntry := func(batch *leveldb.Batch, entry db.GenericEntry) error {
Expand All @@ -155,7 +111,7 @@ func (svdb *VersionCertificateDB) Upsert(savEntries []*VersionCertificateEntry)
if err != nil {
return err
}
batch.Put(addressKey(savEntry.Address), savEntryBytes)
batch.Put(addressKey(savEntry.Address()), savEntryBytes)
newEntries = append(newEntries, savEntry)
logger.Trace("Updating with new entry",
"address", savEntry.Address, "new version", savEntry.Version)
Expand Down Expand Up @@ -190,10 +146,10 @@ func (svdb *VersionCertificateDB) Upsert(savEntries []*VersionCertificateEntry)
return newEntries, nil
}

// Get gets the VersionCertificateEntry entry with address `address`.
// Get gets the istanbul.VersionCertificateEntry entry with address `address`.
// Returns an error if no entry exists.
func (svdb *VersionCertificateDB) Get(address common.Address) (*VersionCertificateEntry, error) {
var entry VersionCertificateEntry
func (svdb *VersionCertificateDB) Get(address common.Address) (*istanbul.VersionCertificate, error) {
var entry istanbul.VersionCertificate
entryBytes, err := svdb.gdb.Get(addressKey(address))
if err != nil {
return nil, err
Expand All @@ -214,10 +170,10 @@ func (svdb *VersionCertificateDB) GetVersion(address common.Address) (uint, erro
return signedAnnVersion.Version, nil
}

// GetAll gets each VersionCertificateEntry in the db
func (svdb *VersionCertificateDB) GetAll() ([]*VersionCertificateEntry, error) {
var entries []*VersionCertificateEntry
err := svdb.iterate(func(address common.Address, entry *VersionCertificateEntry) error {
// GetAll gets each istanbul.VersionCertificateEntry in the db
func (svdb *VersionCertificateDB) GetAll() ([]*istanbul.VersionCertificate, error) {
var entries []*istanbul.VersionCertificate
err := svdb.iterate(func(address common.Address, entry *istanbul.VersionCertificate) error {
entries = append(entries, entry)
return nil
})
Expand All @@ -237,7 +193,7 @@ func (svdb *VersionCertificateDB) Remove(address common.Address) error {
// Prune will remove entries for all addresses not present in addressesToKeep
func (svdb *VersionCertificateDB) Prune(addressesToKeep map[common.Address]bool) error {
batch := new(leveldb.Batch)
err := svdb.iterate(func(address common.Address, entry *VersionCertificateEntry) error {
err := svdb.iterate(func(address common.Address, entry *istanbul.VersionCertificate) error {
if !addressesToKeep[address] {
svdb.logger.Trace("Deleting entry", "address", address)
batch.Delete(addressKey(address))
Expand All @@ -251,13 +207,13 @@ func (svdb *VersionCertificateDB) Prune(addressesToKeep map[common.Address]bool)
}

// iterate will call `onEntry` for each entry in the db
func (svdb *VersionCertificateDB) iterate(onEntry func(common.Address, *VersionCertificateEntry) error) error {
func (svdb *VersionCertificateDB) iterate(onEntry func(common.Address, *istanbul.VersionCertificate) error) error {
logger := svdb.logger.New("func", "iterate")
// Only target address keys
keyPrefix := []byte(dbAddressPrefix)

onDBEntry := func(key []byte, value []byte) error {
var entry VersionCertificateEntry
var entry istanbul.VersionCertificate
if err := rlp.DecodeBytes(value, &entry); err != nil {
return err
}
Expand Down Expand Up @@ -285,9 +241,9 @@ type VersionCertificateEntryInfo struct {
// Intended for RPC use
func (svdb *VersionCertificateDB) Info() (map[string]*VersionCertificateEntryInfo, error) {
dbInfo := make(map[string]*VersionCertificateEntryInfo)
err := svdb.iterate(func(address common.Address, entry *VersionCertificateEntry) error {
err := svdb.iterate(func(address common.Address, entry *istanbul.VersionCertificate) error {
dbInfo[address.Hex()] = &VersionCertificateEntryInfo{
Address: entry.Address.Hex(),
Address: entry.Address().Hex(),
Version: entry.Version,
}
return nil
Expand Down
Loading

0 comments on commit 968e504

Please sign in to comment.