diff --git a/consensus/istanbul/backend/backend_test.go b/consensus/istanbul/backend/backend_test.go index 7afc168e58..31a8c9fe78 100644 --- a/consensus/istanbul/backend/backend_test.go +++ b/consensus/istanbul/backend/backend_test.go @@ -1,22 +1,28 @@ -// Copyright 2019 The klaytn Authors -// This file is part of the klaytn library. +// Modifications Copyright 2020 The klaytn Authors +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. // -// The klaytn library is free software: you can redistribute it and/or modify +// The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // -// The klaytn library is distributed in the hope that it will be useful, +// The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License -// along with the klaytn library. If not, see . +// along with the go-ethereum library. If not, see . +// +// This file is derived from quorum/consensus/istanbul/backend/backend_test.go (2020/04/16). +// Modified and improved for the klaytn development. package backend import ( + "bytes" + "crypto/ecdsa" "fmt" "github.com/hashicorp/golang-lru" "github.com/klaytn/klaytn/blockchain/types" @@ -31,10 +37,14 @@ import ( "github.com/klaytn/klaytn/reward" "github.com/klaytn/klaytn/storage/database" "math/big" + "sort" + "strings" "testing" + "time" ) var ( + testSigningData = []byte("dummy data") // testing node's private key PRIVKEY = "ce7671a2880493dfb8d04218707a16b1532dfcac97f0289d770a919d5ff7b068" // Max blockNum @@ -434,6 +444,20 @@ var ( } ) +type keys []*ecdsa.PrivateKey + +func (slice keys) Len() int { + return len(slice) +} + +func (slice keys) Less(i, j int) bool { + return strings.Compare(crypto.PubkeyToAddress(slice[i].PublicKey).String(), crypto.PubkeyToAddress(slice[j].PublicKey).String()) < 0 +} + +func (slice keys) Swap(i, j int) { + slice[i], slice[j] = slice[j], slice[i] +} + type Pair struct { Sequence int64 Round int64 @@ -838,3 +862,195 @@ func checkInCommitteeBlocks(seq int64, round int64) bool { } return false } + +func newTestBackend() (b *backend) { + dbm := database.NewDBManager(&database.DBConfig{DBType: database.MemoryDB}) + key, _ := crypto.GenerateKey() + istanbul.DefaultConfig.ProposerPolicy = istanbul.WeightedRandom + + backend := New(getTestRewards()[0], istanbul.DefaultConfig, key, dbm, getGovernance(dbm), common.CONSENSUSNODE).(*backend) + return backend +} + +func newTestValidatorSet(n int, policy istanbul.ProposerPolicy) (istanbul.ValidatorSet, []*ecdsa.PrivateKey) { + // generate validators + keys := make(keys, n) + addrs := make([]common.Address, n) + for i := 0; i < n; i++ { + privateKey, _ := crypto.GenerateKey() + keys[i] = privateKey + addrs[i] = crypto.PubkeyToAddress(privateKey.PublicKey) + } + vset := validator.NewSet(addrs, policy) + sort.Sort(keys) //Keys need to be sorted by its public key address + return vset, keys +} + +func TestSign(t *testing.T) { + b := newTestBackend() + + sig, err := b.Sign(testSigningData) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + + //Check signature recover + hashData := crypto.Keccak256([]byte(testSigningData)) + pubkey, _ := crypto.Ecrecover(hashData, sig) + actualSigner := common.BytesToAddress(crypto.Keccak256(pubkey[1:])[12:]) + + if actualSigner != b.address { + t.Errorf("address mismatch: have %v, want %s", actualSigner.Hex(), b.address.Hex()) + } +} + +func TestCheckSignature(t *testing.T) { + b := newTestBackend() + + // testAddr is derived from testPrivateKey. + testPrivateKey, _ := crypto.HexToECDSA("bb047e5940b6d83354d9432db7c449ac8fca2248008aaa7271369880f9f11cc1") + testAddr := common.HexToAddress("0x70524d664ffe731100208a0154e556f9bb679ae6") + testInvalidAddr := common.HexToAddress("0x9535b2e7faaba5288511d89341d94a38063a349b") + + hashData := crypto.Keccak256([]byte(testSigningData)) + sig, err := crypto.Sign(hashData, testPrivateKey) + if err != nil { + t.Fatalf("unexpected failure: %v", err) + } + + if err := b.CheckSignature(testSigningData, testAddr, sig); err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + + if err := b.CheckSignature(testSigningData, testInvalidAddr, sig); err != errInvalidSignature { + t.Errorf("error mismatch: have %v, want %v", err, errInvalidSignature) + } +} + +func TestCheckValidatorSignature(t *testing.T) { + vset, keys := newTestValidatorSet(5, istanbul.WeightedRandom) + + // 1. Positive test: sign with validator's key should succeed + hashData := crypto.Keccak256([]byte(testSigningData)) + for i, k := range keys { + // Sign + sig, err := crypto.Sign(hashData, k) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + // CheckValidatorSignature should succeed + addr, err := istanbul.CheckValidatorSignature(vset, testSigningData, sig) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + validator := vset.GetByIndex(uint64(i)) + if addr != validator.Address() { + t.Errorf("validator address mismatch: have %v, want %v", addr, validator.Address()) + } + } + + // 2. Negative test: sign with any key other than validator's key should return error + key, err := crypto.GenerateKey() + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + // Sign + sig, err := crypto.Sign(hashData, key) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + // CheckValidatorSignature should return ErrUnauthorizedAddress + addr, err := istanbul.CheckValidatorSignature(vset, testSigningData, sig) + if err != istanbul.ErrUnauthorizedAddress { + t.Errorf("error mismatch: have %v, want %v", err, istanbul.ErrUnauthorizedAddress) + } + emptyAddr := common.Address{} + if addr != emptyAddr { + t.Errorf("address mismatch: have %v, want %v", addr, emptyAddr) + } +} + +func TestCommit(t *testing.T) { + backend := newTestBackend() + + commitCh := make(chan *types.Block) + // Case: it's a proposer, so the backend.commit will receive channel result from backend.Commit function + testCases := []struct { + expectedErr error + expectedSignature [][]byte + expectedBlock func() *types.Block + }{ + { + // normal case + nil, + [][]byte{append([]byte{1}, bytes.Repeat([]byte{0x00}, types.IstanbulExtraSeal-1)...)}, + func() *types.Block { + chain, engine := newBlockChain(1) + defer engine.Stop() + + block := makeBlockWithoutSeal(chain, engine, chain.Genesis()) + expectedBlock, _ := engine.updateBlock(engine.chain.GetHeader(block.ParentHash(), block.NumberU64()-1), block) + return expectedBlock + }, + }, + { + // invalid signature + errInvalidCommittedSeals, + nil, + func() *types.Block { + chain, engine := newBlockChain(1) + defer engine.Stop() + + block := makeBlockWithoutSeal(chain, engine, chain.Genesis()) + expectedBlock, _ := engine.updateBlock(engine.chain.GetHeader(block.ParentHash(), block.NumberU64()-1), block) + return expectedBlock + }, + }, + } + + for _, test := range testCases { + expBlock := test.expectedBlock() + go func() { + select { + case result := <-backend.commitCh: + commitCh <- result.Block + return + } + }() + + backend.proposedBlockHash = expBlock.Hash() + if err := backend.Commit(expBlock, test.expectedSignature); err != nil { + if err != test.expectedErr { + t.Errorf("error mismatch: have %v, want %v", err, test.expectedErr) + } + } + + if test.expectedErr == nil { + // to avoid race condition is occurred by goroutine + select { + case result := <-commitCh: + if result.Hash() != expBlock.Hash() { + t.Errorf("hash mismatch: have %v, want %v", result.Hash(), expBlock.Hash()) + } + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + } + } +} + +func TestGetProposer(t *testing.T) { + chain, engine := newBlockChain(1) + defer engine.Stop() + + block := makeBlock(chain, engine, chain.Genesis()) + _, err := chain.InsertChain(types.Blocks{block}) + if err != nil { + t.Errorf("failed to insert chain: %v", err) + } + expected := engine.GetProposer(1) + actual := engine.Address() + if actual != expected { + t.Errorf("proposer mismatch: have %v, want %v", actual.Hex(), expected.Hex()) + } +} diff --git a/consensus/istanbul/backend/engine_test.go b/consensus/istanbul/backend/engine_test.go new file mode 100644 index 0000000000..4aed5c86cb --- /dev/null +++ b/consensus/istanbul/backend/engine_test.go @@ -0,0 +1,507 @@ +// Modifications Copyright 2020 The klaytn Authors +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . +// +// This file is derived from quorum/consensus/istanbul/backend/engine_test.go (2020/04/16). +// Modified and improved for the klaytn development. + +package backend + +import ( + "bytes" + "crypto/ecdsa" + "github.com/klaytn/klaytn/blockchain" + "github.com/klaytn/klaytn/blockchain/types" + "github.com/klaytn/klaytn/blockchain/vm" + "github.com/klaytn/klaytn/common" + "github.com/klaytn/klaytn/common/hexutil" + "github.com/klaytn/klaytn/consensus" + "github.com/klaytn/klaytn/consensus/istanbul" + "github.com/klaytn/klaytn/crypto" + "github.com/klaytn/klaytn/params" + "github.com/klaytn/klaytn/ser/rlp" + "math/big" + "reflect" + "testing" + "time" +) + +// in this test, we can set n to 1, and it means we can process Istanbul and commit a +// block by one node. Otherwise, if n is larger than 1, we have to generate +// other fake events to process Istanbul. +func newBlockChain(n int) (*blockchain.BlockChain, *backend) { + var nodeKeys = make([]*ecdsa.PrivateKey, n) + var addrs = make([]common.Address, n) + + b := newTestBackend() + + nodeKeys[0] = b.privateKey + addrs[0] = b.address + for i := 1; i < n; i++ { + nodeKeys[i], _ = crypto.GenerateKey() + addrs[i] = crypto.PubkeyToAddress(nodeKeys[i].PublicKey) + } + + // generate a genesis block + genesis := blockchain.DefaultGenesisBlock() + genesis.Config = params.TestChainConfig + genesis.Timestamp = uint64(time.Now().Unix()) + + // force enable Istanbul engine + genesis.Config.Istanbul = ¶ms.IstanbulConfig{} + appendValidators(genesis, addrs) + + genesis.MustCommit(b.db) + + bc, err := blockchain.NewBlockChain(b.db, nil, genesis.Config, b, vm.Config{}) + if err != nil { + panic(err) + } + if b.Start(bc, bc.CurrentBlock, bc.HasBadBlock) != nil { + panic(err) + } + + return bc, b +} + +func appendValidators(genesis *blockchain.Genesis, addrs []common.Address) { + if len(genesis.ExtraData) < types.IstanbulExtraVanity { + genesis.ExtraData = append(genesis.ExtraData, bytes.Repeat([]byte{0x00}, types.IstanbulExtraVanity)...) + } + genesis.ExtraData = genesis.ExtraData[:types.IstanbulExtraVanity] + + ist := &types.IstanbulExtra{ + Validators: addrs, + Seal: []byte{}, + CommittedSeal: [][]byte{}, + } + + istPayload, err := rlp.EncodeToBytes(&ist) + if err != nil { + panic("failed to encode istanbul extra") + } + genesis.ExtraData = append(genesis.ExtraData, istPayload...) +} + +func makeHeader(parent *types.Block, config *istanbul.Config) *types.Header { + header := &types.Header{ + ParentHash: parent.Hash(), + Number: parent.Number().Add(parent.Number(), common.Big1), + GasUsed: 0, + Extra: parent.Extra(), + Time: new(big.Int).Add(parent.Time(), new(big.Int).SetUint64(config.BlockPeriod)), + BlockScore: defaultBlockScore, + } + return header +} + +func makeBlock(chain *blockchain.BlockChain, engine *backend, parent *types.Block) *types.Block { + block := makeBlockWithoutSeal(chain, engine, parent) + stopCh := make(chan struct{}) + result, err := engine.Seal(chain, block, stopCh) + if err != nil { + panic(err) + } + return result +} + +func makeBlockWithoutSeal(chain *blockchain.BlockChain, engine *backend, parent *types.Block) *types.Block { + header := makeHeader(parent, engine.config) + if err := engine.Prepare(chain, header); err != nil { + panic(err) + } + state, _ := chain.StateAt(parent.Root()) + block, _ := engine.Finalize(chain, header, state, nil, nil) + return block +} + +func TestPrepare(t *testing.T) { + chain, engine := newBlockChain(1) + defer engine.Stop() + + header := makeHeader(chain.Genesis(), engine.config) + err := engine.Prepare(chain, header) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + + header.ParentHash = common.HexToHash("0x1234567890") + err = engine.Prepare(chain, header) + if err != consensus.ErrUnknownAncestor { + t.Errorf("error mismatch: have %v, want %v", err, consensus.ErrUnknownAncestor) + } +} + +func TestSealStopChannel(t *testing.T) { + chain, engine := newBlockChain(4) + defer engine.Stop() + + block := makeBlockWithoutSeal(chain, engine, chain.Genesis()) + stop := make(chan struct{}, 1) + eventSub := engine.EventMux().Subscribe(istanbul.RequestEvent{}) + eventLoop := func() { + select { + case ev := <-eventSub.Chan(): + _, ok := ev.Data.(istanbul.RequestEvent) + if !ok { + t.Errorf("unexpected event comes: %v", reflect.TypeOf(ev.Data)) + } + stop <- struct{}{} + } + eventSub.Unsubscribe() + } + go eventLoop() + + finalBlock, err := engine.Seal(chain, block, stop) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + + if finalBlock != nil { + t.Errorf("block mismatch: have %v, want nil", finalBlock) + } +} + +func TestSealCommitted(t *testing.T) { + chain, engine := newBlockChain(1) + defer engine.Stop() + + block := makeBlockWithoutSeal(chain, engine, chain.Genesis()) + expectedBlock, _ := engine.updateBlock(engine.chain.GetHeader(block.ParentHash(), block.NumberU64()-1), block) + + actualBlock, err := engine.Seal(chain, block, make(chan struct{})) + if err != nil { + t.Errorf("error mismatch: have %v, want %v", err, expectedBlock) + } + + if actualBlock.Hash() != expectedBlock.Hash() { + t.Errorf("hash mismatch: have %v, want %v", actualBlock.Hash(), expectedBlock.Hash()) + } +} + +func TestVerifyHeader(t *testing.T) { + chain, engine := newBlockChain(1) + defer engine.Stop() + + // errEmptyCommittedSeals case + block := makeBlockWithoutSeal(chain, engine, chain.Genesis()) + block, _ = engine.updateBlock(chain.Genesis().Header(), block) + err := engine.VerifyHeader(chain, block.Header(), false) + if err != errEmptyCommittedSeals { + t.Errorf("error mismatch: have %v, want %v", err, errEmptyCommittedSeals) + } + + // short extra data + header := block.Header() + header.Extra = []byte{} + err = engine.VerifyHeader(chain, header, false) + if err != errInvalidExtraDataFormat { + t.Errorf("error mismatch: have %v, want %v", err, errInvalidExtraDataFormat) + } + // incorrect extra format + header.Extra = []byte("0000000000000000000000000000000012300000000000000000000000000000000000000000000000000000000000000000") + err = engine.VerifyHeader(chain, header, false) + if err != errInvalidExtraDataFormat { + t.Errorf("error mismatch: have %v, want %v", err, errInvalidExtraDataFormat) + } + + // invalid difficulty + block = makeBlockWithoutSeal(chain, engine, chain.Genesis()) + header = block.Header() + header.BlockScore = big.NewInt(2) + err = engine.VerifyHeader(chain, header, false) + if err != errInvalidBlockScore { + t.Errorf("error mismatch: have %v, want %v", err, errInvalidBlockScore) + } + + // invalid timestamp + block = makeBlockWithoutSeal(chain, engine, chain.Genesis()) + header = block.Header() + header.Time = new(big.Int).Add(chain.Genesis().Time(), new(big.Int).SetUint64(engine.config.BlockPeriod-1)) + err = engine.VerifyHeader(chain, header, false) + if err != errInvalidTimestamp { + t.Errorf("error mismatch: have %v, want %v", err, errInvalidTimestamp) + } + + // future block + block = makeBlockWithoutSeal(chain, engine, chain.Genesis()) + header = block.Header() + header.Time = new(big.Int).Add(big.NewInt(now().Unix()), new(big.Int).SetUint64(10)) + err = engine.VerifyHeader(chain, header, false) + if err != consensus.ErrFutureBlock { + t.Errorf("error mismatch: have %v, want %v", err, consensus.ErrFutureBlock) + } + + // TODO-Klaytn: add more tests for header.Governance, header.Rewardbase, header.Vote +} + +func TestVerifySeal(t *testing.T) { + chain, engine := newBlockChain(1) + defer engine.Stop() + + genesis := chain.Genesis() + + // cannot verify genesis + err := engine.VerifySeal(chain, genesis.Header()) + if err != errUnknownBlock { + t.Errorf("error mismatch: have %v, want %v", err, errUnknownBlock) + } + block := makeBlock(chain, engine, genesis) + + // change block content + header := block.Header() + header.Number = big.NewInt(4) + block1 := block.WithSeal(header) + err = engine.VerifySeal(chain, block1.Header()) + if err != errUnauthorized { + t.Errorf("error mismatch: have %v, want %v", err, errUnauthorized) + } + + // unauthorized users but still can get correct signer address + engine.privateKey, _ = crypto.GenerateKey() + err = engine.VerifySeal(chain, block.Header()) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } +} + +func TestVerifyHeaders(t *testing.T) { + chain, engine := newBlockChain(1) + defer engine.Stop() + + genesis := chain.Genesis() + + // success case + headers := []*types.Header{} + blocks := []*types.Block{} + size := 100 + + for i := 0; i < size; i++ { + var b *types.Block + if i == 0 { + b = makeBlockWithoutSeal(chain, engine, genesis) + b, _ = engine.updateBlock(genesis.Header(), b) + engine.db.WriteHeader(b.Header()) + } else { + b = makeBlockWithoutSeal(chain, engine, blocks[i-1]) + b, _ = engine.updateBlock(blocks[i-1].Header(), b) + engine.db.WriteHeader(b.Header()) + } + blocks = append(blocks, b) + headers = append(headers, blocks[i].Header()) + } + + // proceed time to avoid future block errors + now = func() time.Time { + return time.Unix(headers[size-1].Time.Int64(), 0) + } + defer func() { + now = time.Now + }() + + _, results := engine.VerifyHeaders(chain, headers, nil) + const timeoutDura = 2 * time.Second + timeout := time.NewTimer(timeoutDura) + index := 0 +OUT1: + for { + select { + case err := <-results: + if err != nil { + if err != errEmptyCommittedSeals && err != errInvalidCommittedSeals { + t.Errorf("error mismatch: have %v, want errEmptyCommittedSeals|errInvalidCommittedSeals", err) + break OUT1 + } + } + index++ + if index == size { + break OUT1 + } + case <-timeout.C: + break OUT1 + } + } + // abort cases + abort, results := engine.VerifyHeaders(chain, headers, nil) + timeout = time.NewTimer(timeoutDura) + index = 0 +OUT2: + for { + select { + case err := <-results: + if err != nil { + if err != errEmptyCommittedSeals && err != errInvalidCommittedSeals { + t.Errorf("error mismatch: have %v, want errEmptyCommittedSeals|errInvalidCommittedSeals", err) + break OUT2 + } + } + index++ + if index == 5 { + abort <- struct{}{} + } + if index >= size { + t.Errorf("verifyheaders should be aborted") + break OUT2 + } + case <-timeout.C: + break OUT2 + } + } + // error header cases + headers[2].Number = big.NewInt(100) + abort, results = engine.VerifyHeaders(chain, headers, nil) + timeout = time.NewTimer(timeoutDura) + index = 0 + errors := 0 + expectedErrors := 2 +OUT3: + for { + select { + case err := <-results: + if err != nil { + if err != errEmptyCommittedSeals && err != errInvalidCommittedSeals { + errors++ + } + } + index++ + if index == size { + if errors != expectedErrors { + t.Errorf("error mismatch: have %v, want %v", err, expectedErrors) + } + break OUT3 + } + case <-timeout.C: + break OUT3 + } + } +} + +func TestPrepareExtra(t *testing.T) { + validators := make([]common.Address, 4) + validators[0] = common.BytesToAddress(hexutil.MustDecode("0x44add0ec310f115a0e603b2d7db9f067778eaf8a")) + validators[1] = common.BytesToAddress(hexutil.MustDecode("0x294fc7e8f22b3bcdcf955dd7ff3ba2ed833f8212")) + validators[2] = common.BytesToAddress(hexutil.MustDecode("0x6beaaed781d2d2ab6350f5c4566a2c6eaac407a6")) + validators[3] = common.BytesToAddress(hexutil.MustDecode("0x8be76812f765c24641ec63dc2852b378aba2b440")) + + vanity := make([]byte, types.IstanbulExtraVanity) + expectedResult := append(vanity, hexutil.MustDecode("0xf858f8549444add0ec310f115a0e603b2d7db9f067778eaf8a94294fc7e8f22b3bcdcf955dd7ff3ba2ed833f8212946beaaed781d2d2ab6350f5c4566a2c6eaac407a6948be76812f765c24641ec63dc2852b378aba2b44080c0")...) + + h := &types.Header{ + Extra: vanity, + } + + payload, err := prepareExtra(h, validators) + if err != nil { + t.Errorf("error mismatch: have %v, want: nil", err) + } + if !reflect.DeepEqual(payload, expectedResult) { + t.Errorf("payload mismatch: have %v, want %v", payload, expectedResult) + } + + // append useless information to extra-data + h.Extra = append(vanity, make([]byte, 15)...) + + payload, err = prepareExtra(h, validators) + if !reflect.DeepEqual(payload, expectedResult) { + t.Errorf("payload mismatch: have %v, want %v", payload, expectedResult) + } +} + +func TestWriteSeal(t *testing.T) { + vanity := bytes.Repeat([]byte{0x00}, types.IstanbulExtraVanity) + istRawData := hexutil.MustDecode("0xf858f8549444add0ec310f115a0e603b2d7db9f067778eaf8a94294fc7e8f22b3bcdcf955dd7ff3ba2ed833f8212946beaaed781d2d2ab6350f5c4566a2c6eaac407a6948be76812f765c24641ec63dc2852b378aba2b44080c0") + expectedSeal := append([]byte{1, 2, 3}, bytes.Repeat([]byte{0x00}, types.IstanbulExtraSeal-3)...) + expectedIstExtra := &types.IstanbulExtra{ + Validators: []common.Address{ + common.BytesToAddress(hexutil.MustDecode("0x44add0ec310f115a0e603b2d7db9f067778eaf8a")), + common.BytesToAddress(hexutil.MustDecode("0x294fc7e8f22b3bcdcf955dd7ff3ba2ed833f8212")), + common.BytesToAddress(hexutil.MustDecode("0x6beaaed781d2d2ab6350f5c4566a2c6eaac407a6")), + common.BytesToAddress(hexutil.MustDecode("0x8be76812f765c24641ec63dc2852b378aba2b440")), + }, + Seal: expectedSeal, + CommittedSeal: [][]byte{}, + } + var expectedErr error + + h := &types.Header{ + Extra: append(vanity, istRawData...), + } + + // normal case + err := writeSeal(h, expectedSeal) + if err != expectedErr { + t.Errorf("error mismatch: have %v, want %v", err, expectedErr) + } + + // verify istanbul extra-data + istExtra, err := types.ExtractIstanbulExtra(h) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + if !reflect.DeepEqual(istExtra, expectedIstExtra) { + t.Errorf("extra data mismatch: have %v, want %v", istExtra, expectedIstExtra) + } + + // invalid seal + unexpectedSeal := append(expectedSeal, make([]byte, 1)...) + err = writeSeal(h, unexpectedSeal) + if err != errInvalidSignature { + t.Errorf("error mismatch: have %v, want %v", err, errInvalidSignature) + } +} + +func TestWriteCommittedSeals(t *testing.T) { + vanity := bytes.Repeat([]byte{0x00}, types.IstanbulExtraVanity) + istRawData := hexutil.MustDecode("0xf858f8549444add0ec310f115a0e603b2d7db9f067778eaf8a94294fc7e8f22b3bcdcf955dd7ff3ba2ed833f8212946beaaed781d2d2ab6350f5c4566a2c6eaac407a6948be76812f765c24641ec63dc2852b378aba2b44080c0") + expectedCommittedSeal := append([]byte{1, 2, 3}, bytes.Repeat([]byte{0x00}, types.IstanbulExtraSeal-3)...) + expectedIstExtra := &types.IstanbulExtra{ + Validators: []common.Address{ + common.BytesToAddress(hexutil.MustDecode("0x44add0ec310f115a0e603b2d7db9f067778eaf8a")), + common.BytesToAddress(hexutil.MustDecode("0x294fc7e8f22b3bcdcf955dd7ff3ba2ed833f8212")), + common.BytesToAddress(hexutil.MustDecode("0x6beaaed781d2d2ab6350f5c4566a2c6eaac407a6")), + common.BytesToAddress(hexutil.MustDecode("0x8be76812f765c24641ec63dc2852b378aba2b440")), + }, + Seal: []byte{}, + CommittedSeal: [][]byte{expectedCommittedSeal}, + } + var expectedErr error + + h := &types.Header{ + Extra: append(vanity, istRawData...), + } + + // normal case + err := writeCommittedSeals(h, [][]byte{expectedCommittedSeal}) + if err != expectedErr { + t.Errorf("error mismatch: have %v, want %v", err, expectedErr) + } + + // verify istanbul extra-data + istExtra, err := types.ExtractIstanbulExtra(h) + if err != nil { + t.Errorf("error mismatch: have %v, want nil", err) + } + if !reflect.DeepEqual(istExtra, expectedIstExtra) { + t.Errorf("extra data mismatch: have %v, want %v", istExtra, expectedIstExtra) + } + + // invalid seal + unexpectedCommittedSeal := append(expectedCommittedSeal, make([]byte, 1)...) + err = writeCommittedSeals(h, [][]byte{unexpectedCommittedSeal}) + if err != errInvalidCommittedSeals { + t.Errorf("error mismatch: have %v, want %v", err, errInvalidCommittedSeals) + } +}