From 05178a56c97ce6c079bbf6416654051a75cce900 Mon Sep 17 00:00:00 2001 From: rjl493456442 Date: Wed, 30 Oct 2019 20:21:29 +0800 Subject: [PATCH] les: split handshake les, core: introduce forkid for les4 les: update peer unit test les: check peer version in handler les: fix linter les: address comments --- core/forkid/forkid.go | 18 +- les/client_handler.go | 27 +-- les/peer.go | 210 +++++++++++++--------- les/peer_test.go | 393 ++++++++++++++---------------------------- les/protocol.go | 9 +- les/server_handler.go | 5 +- 6 files changed, 300 insertions(+), 362 deletions(-) diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go index e433db44608c..08a948510cb5 100644 --- a/core/forkid/forkid.go +++ b/core/forkid/forkid.go @@ -27,7 +27,7 @@ import ( "strings" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" ) @@ -44,6 +44,18 @@ var ( ErrLocalIncompatibleOrStale = errors.New("local incompatible or needs update") ) +// Blockchain defines all necessary method to build a forkID. +type Blockchain interface { + // Config retrieves the chain's fork configuration. + Config() *params.ChainConfig + + // Genesis retrieves the chain's genesis block. + Genesis() *types.Block + + // CurrentHeader retrieves the current head header of the canonical chain. + CurrentHeader() *types.Header +} + // ID is a fork identifier as defined by EIP-2124. type ID struct { Hash [4]byte // CRC32 checksum of the genesis block and passed fork block numbers @@ -54,7 +66,7 @@ type ID struct { type Filter func(id ID) error // NewID calculates the Ethereum fork ID from the chain config and head. -func NewID(chain *core.BlockChain) ID { +func NewID(chain Blockchain) ID { return newID( chain.Config(), chain.Genesis().Hash(), @@ -85,7 +97,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { // NewFilter creates a filter that returns if a fork ID should be rejected or not // based on the local chain's status. -func NewFilter(chain *core.BlockChain) Filter { +func NewFilter(chain Blockchain) Filter { return newFilter( chain.Config(), chain.Genesis().Hash(), diff --git a/les/client_handler.go b/les/client_handler.go index 7fdb1657194c..29f80e04a91c 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -23,6 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/light" @@ -35,6 +36,7 @@ import ( // responses. type clientHandler struct { ulc *ulc + forkFilter forkid.Filter checkpoint *params.TrustedCheckpoint fetcher *lightFetcher downloader *downloader.Downloader @@ -47,6 +49,7 @@ type clientHandler struct { func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler { handler := &clientHandler{ + forkFilter: forkid.NewFilter(backend.blockchain), checkpoint: checkpoint, backend: backend, closeCh: make(chan struct{}), @@ -107,7 +110,7 @@ func (h *clientHandler) handle(p *peer) error { number = head.Number.Uint64() td = h.backend.blockchain.GetTd(hash, number) ) - if err := p.Handshake(td, hash, number, h.backend.blockchain.Genesis().Hash(), nil); err != nil { + if err := p.handshakeWithServer(td, hash, number, h.backend.blockchain.Genesis().Hash(), forkid.NewID(h.backend.blockchain), h.forkFilter); err != nil { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err } @@ -159,8 +162,8 @@ func (h *clientHandler) handleMsg(p *peer) error { var deliverMsg *Msg // Handle the message depending on its contents - switch msg.Code { - case AnnounceMsg: + switch { + case msg.Code == AnnounceMsg: p.Log().Trace("Received announce message") var req announceData if err := msg.Decode(&req); err != nil { @@ -189,7 +192,7 @@ func (h *clientHandler) handleMsg(p *peer) error { p.Log().Trace("Announce message content", "number", req.Number, "hash", req.Hash, "td", req.Td, "reorg", req.ReorgDepth) h.fetcher.announce(p, &req) } - case BlockHeadersMsg: + case msg.Code == BlockHeadersMsg: p.Log().Trace("Received block header response message") var resp struct { ReqID, BV uint64 @@ -206,7 +209,7 @@ func (h *clientHandler) handleMsg(p *peer) error { log.Debug("Failed to deliver headers", "err", err) } } - case BlockBodiesMsg: + case msg.Code == BlockBodiesMsg: p.Log().Trace("Received block bodies response") var resp struct { ReqID, BV uint64 @@ -221,7 +224,7 @@ func (h *clientHandler) handleMsg(p *peer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case CodeMsg: + case msg.Code == CodeMsg: p.Log().Trace("Received code response") var resp struct { ReqID, BV uint64 @@ -236,7 +239,7 @@ func (h *clientHandler) handleMsg(p *peer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case ReceiptsMsg: + case msg.Code == ReceiptsMsg: p.Log().Trace("Received receipts response") var resp struct { ReqID, BV uint64 @@ -251,7 +254,7 @@ func (h *clientHandler) handleMsg(p *peer) error { ReqID: resp.ReqID, Obj: resp.Receipts, } - case ProofsV2Msg: + case msg.Code == ProofsV2Msg: p.Log().Trace("Received les/2 proofs response") var resp struct { ReqID, BV uint64 @@ -266,7 +269,7 @@ func (h *clientHandler) handleMsg(p *peer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case HelperTrieProofsMsg: + case msg.Code == HelperTrieProofsMsg: p.Log().Trace("Received helper trie proof response") var resp struct { ReqID, BV uint64 @@ -281,7 +284,7 @@ func (h *clientHandler) handleMsg(p *peer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case TxStatusMsg: + case msg.Code == TxStatusMsg: p.Log().Trace("Received tx status response") var resp struct { ReqID, BV uint64 @@ -296,11 +299,11 @@ func (h *clientHandler) handleMsg(p *peer) error { ReqID: resp.ReqID, Obj: resp.Status, } - case StopMsg: + case msg.Code == StopMsg && p.version >= lpv3: p.freezeServer(true) h.backend.retriever.frozen(p) p.Log().Debug("Service stopped") - case ResumeMsg: + case msg.Code == ResumeMsg && p.version >= lpv3: var bv uint64 if err := msg.Decode(&bv); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) diff --git a/les/peer.go b/les/peer.go index ab5b30a6571a..0e54fb77cc5a 100644 --- a/les/peer.go +++ b/les/peer.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" @@ -45,6 +46,11 @@ var ( errNotRegistered = errors.New("peer is not registered") ) +var ( + s = rand.NewSource(time.Now().UnixNano()) + r = rand.New(s) +) + const ( maxRequestErrors = 20 // number of invalid requests tolerated (makes the protocol less brittle but still avoids spam) maxResponseErrors = 50 // number of invalid responses tolerated (makes the protocol less brittle but still avoids spam) @@ -542,7 +548,7 @@ func (m keyValueMap) get(key string, val interface{}) error { return rlp.DecodeBytes(enc, val) } -func (p *peer) sendReceiveHandshake(sendList keyValueList) (keyValueList, error) { +func (p *peer) exchangeHandshake(sendList keyValueList) (keyValueList, error) { // Send out own handshake in a new thread errc := make(chan error, 1) go func() { @@ -570,9 +576,11 @@ func (p *peer) sendReceiveHandshake(sendList keyValueList) (keyValueList, error) return recvList, nil } -// Handshake executes the les protocol handshake, negotiating version number, -// network IDs, difficulties, head and genesis blocks. -func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error { +// handshake executes the les protocol handshake, negotiating version number, +// network IDs, difficulties, head and genesis blocks. Besides the basic handshake +// fields, server and client can exchange and resolve some specified fields through +// two callback functions. +func (p *peer) handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter, sendCallback func(*keyValueList), recvCallback func(keyValueMap) error) error { p.lock.Lock() defer p.lock.Unlock() @@ -585,54 +593,19 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis send = send.add("headHash", head) send = send.add("headNum", headNum) send = send.add("genesisHash", genesis) - if server != nil { - // Add some information which services server can offer. - if !server.config.UltraLightOnlyAnnounce { - send = send.add("serveHeaders", nil) - send = send.add("serveChainSince", uint64(0)) - send = send.add("serveStateSince", uint64(0)) - - // If local ethereum node is running in archive mode, advertise ourselves we have - // all version state data. Otherwise only recent state is available. - stateRecent := uint64(core.TriesInMemory - 4) - if server.archiveMode { - stateRecent = 0 - } - send = send.add("serveRecentState", stateRecent) - send = send.add("txRelay", nil) - } - send = send.add("flowControl/BL", server.defParams.BufLimit) - send = send.add("flowControl/MRR", server.defParams.MinRecharge) - var costList RequestCostList - if server.costTracker.testCostList != nil { - costList = server.costTracker.testCostList - } else { - costList = server.costTracker.makeCostList(server.costTracker.globalFactor()) - } - send = send.add("flowControl/MRC", costList) - p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) - p.fcParams = server.defParams - - // Add advertised checkpoint and register block height which - // client can verify the checkpoint validity. - if server.oracle != nil && server.oracle.isRunning() { - cp, height := server.oracle.stableCheckpoint() - if cp != nil { - send = send.add("checkpoint/value", cp) - send = send.add("checkpoint/registerHeight", height) - } - } - } else { - // Add some client-specific handshake fields - p.announceType = announceTypeSimple - if p.trusted { - p.announceType = announceTypeSigned - } - send = send.add("announceType", p.announceType) + // If the protocol version is beyond les4, then pass the forkID + // as well. Check http://eips.ethereum.org/EIPS/eip-2124 for more + // spec detail. + if p.version >= lpv4 { + send = send.add("forkID", forkID) } - - recvList, err := p.sendReceiveHandshake(send) + // Add client-specified or server-specified fields + if sendCallback != nil { + sendCallback(&send) + } + // Exchange the handshake packet and resolve the received one. + recvList, err := p.exchangeHandshake(send) if err != nil { return err } @@ -640,47 +613,73 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis if p.rejectUpdate(size) { return errResp(ErrRequestRejected, "") } - - var rGenesis, rHash common.Hash - var rVersion, rNetwork, rNum uint64 - var rTd *big.Int - - if err := recv.get("protocolVersion", &rVersion); err != nil { + // Check and compare the protocol version of remote peer + var remoteVersion uint64 + if err := recv.get("protocolVersion", &remoteVersion); err != nil { return err } - if err := recv.get("networkId", &rNetwork); err != nil { - return err + if int(remoteVersion) != p.version { + return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", remoteVersion, p.version) } - if err := recv.get("headTd", &rTd); err != nil { + // Check and compare the network id of remote peer + var remoteNetwork uint64 + if err := recv.get("networkId", &remoteNetwork); err != nil { return err } - if err := recv.get("headHash", &rHash); err != nil { - return err + if remoteNetwork != p.network { + return errResp(ErrNetworkIdMismatch, "%d (!= %d)", remoteNetwork, p.network) } - if err := recv.get("headNum", &rNum); err != nil { + // Check and compare the genesis of remote peer + var remoteGenesis common.Hash + if err := recv.get("genesisHash", &remoteGenesis); err != nil { return err } - if err := recv.get("genesisHash", &rGenesis); err != nil { + if remoteGenesis != genesis { + return errResp(ErrGenesisBlockMismatch, "%x (!= %x)", remoteGenesis[:8], genesis[:8]) + } + // Check forkID if the protocol version is beyond the les4 + if p.version >= lpv4 { + var forkID forkid.ID + if err := recv.get("forkID", &forkID); err != nil { + return err + } + if err := forkFilter(forkID); err != nil { + return errResp(ErrForkIDRejected, "%v", err) + } + } + // Pass all checks, extract the remaning fields. + var remoteId *big.Int + if err := recv.get("headTd", &remoteId); err != nil { return err } - - if rGenesis != genesis { - return errResp(ErrGenesisBlockMismatch, "%x (!= %x)", rGenesis[:8], genesis[:8]) + var remoteHead common.Hash + if err := recv.get("headHash", &remoteHead); err != nil { + return err } - if rNetwork != p.network { - return errResp(ErrNetworkIdMismatch, "%d (!= %d)", rNetwork, p.network) + var remoteHeadNum uint64 + if err := recv.get("headNum", &remoteHeadNum); err != nil { + return err } - if int(rVersion) != p.version { - return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", rVersion, p.version) + p.headInfo = &announceData{Hash: remoteHead, Number: remoteHeadNum, Td: remoteId} + if recvCallback != nil { + return recvCallback(recv) } + return nil +} - if server != nil { - if recv.get("announceType", &p.announceType) != nil { - // set default announceType on server side - p.announceType = announceTypeSimple +// handshakeWithServer executes the les protocol handshake with les server, negotiating +// version number, network IDs, difficulties, head and genesis blocks. +func (p *peer) handshakeWithServer(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error { + return p.handshake(td, head, headNum, genesis, forkID, forkFilter, func(lists *keyValueList) { + // Add some client-specific handshake fields + // + // Enable signed announcement randomly even the server is not trusted. + p.announceType = announceTypeSimple + if p.trusted || r.Intn(10) > 3 { + p.announceType = announceTypeSigned } - p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) - } else { + *lists = (*lists).add("announceType", p.announceType) + }, func(recv keyValueMap) error { if recv.get("serveChainSince", &p.chainSince) != nil { p.onlyAnnounce = true } @@ -696,11 +695,10 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis if recv.get("txRelay", nil) != nil { p.onlyAnnounce = true } - if p.onlyAnnounce && !p.trusted { return errResp(ErrUselessPeer, "peer cannot serve requests") } - + // Parse flow control handshake packet. var sParams flowcontrol.ServerParams if err := recv.get("flowControl/BL", &sParams.BufLimit); err != nil { return err @@ -726,9 +724,59 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis } } } - } - p.headInfo = &announceData{Td: rTd, Hash: rHash, Number: rNum} - return nil + return nil + }) +} + +// handshakeWithClient executes the les protocol handshake with les client, negotiating +// version number, network IDs, difficulties, head and genesis blocks. +func (p *peer) handshakeWithClient(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter, server *LesServer) error { + return p.handshake(td, head, headNum, genesis, forkID, forkFilter, func(lists *keyValueList) { + // Add some information which services server can offer. + if !server.config.UltraLightOnlyAnnounce { + *lists = (*lists).add("serveHeaders", nil) + *lists = (*lists).add("serveChainSince", uint64(0)) + *lists = (*lists).add("serveStateSince", uint64(0)) + + // If local ethereum node is running in archive mode, advertise ourselves we have + // all version state data. Otherwise only recent state is available. + stateRecent := uint64(core.TriesInMemory - 4) + if server.archiveMode { + stateRecent = 0 + } + *lists = (*lists).add("serveRecentState", stateRecent) + *lists = (*lists).add("txRelay", nil) + } + *lists = (*lists).add("flowControl/BL", server.defParams.BufLimit) + *lists = (*lists).add("flowControl/MRR", server.defParams.MinRecharge) + + var costList RequestCostList + if server.costTracker.testCostList != nil { + costList = server.costTracker.testCostList + } else { + costList = server.costTracker.makeCostList(server.costTracker.globalFactor()) + } + *lists = (*lists).add("flowControl/MRC", costList) + p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) + p.fcParams = server.defParams + + // Add advertised checkpoint and register block height which + // client can verify the checkpoint validity. + if server.oracle != nil && server.oracle.isRunning() { + cp, height := server.oracle.stableCheckpoint() + if cp != nil { + *lists = (*lists).add("checkpoint/value", cp) + *lists = (*lists).add("checkpoint/registerHeight", height) + } + } + }, func(recv keyValueMap) error { + if recv.get("announceType", &p.announceType) != nil { + // set default announceType on server side + p.announceType = announceTypeSimple + } + p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) + return nil + }) } // updateFlowControl updates the flow control parameters belonging to the server diff --git a/les/peer_test.go b/les/peer_test.go index db74a052c14c..cb3f3fc46274 100644 --- a/les/peer_test.go +++ b/les/peer_test.go @@ -17,286 +17,155 @@ package les import ( + "crypto/rand" "math/big" - "net" + "reflect" + "sort" "testing" + "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/eth" - "github.com/ethereum/go-ethereum/les/flowcontrol" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/params" ) -const protocolVersion = lpv2 - -var ( - hash = common.HexToHash("deadbeef") - genesis = common.HexToHash("cafebabe") - headNum = uint64(1234) - td = big.NewInt(123) -) - -func newNodeID(t *testing.T) *enode.Node { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatal("generate key err:", err) - } - return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000) +type testServerPeerSub struct { + regCh chan *peer + unregCh chan *peer } -// ulc connects to trusted peer and send announceType=announceTypeSigned -func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) { - id := newNodeID(t).ID() - - // peer to connect(on ulc side) - p := peer{ - Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocolVersion, - trusted: true, - rw: &rwStub{ - WriteHook: func(recvList keyValueList) { - recv, _ := recvList.decode() - var reqType uint64 - err := recv.get("announceType", &reqType) - if err != nil { - t.Fatal(err) - } - if reqType != announceTypeSigned { - t.Fatal("Expected announceTypeSigned") - } - }, - ReadHook: func(l keyValueList) keyValueList { - l = l.add("serveHeaders", nil) - l = l.add("serveChainSince", uint64(0)) - l = l.add("serveStateSince", uint64(0)) - l = l.add("txRelay", nil) - l = l.add("flowControl/BL", uint64(0)) - l = l.add("flowControl/MRR", uint64(0)) - l = l.add("flowControl/MRC", testCostList(0)) - return l - }, - }, - network: NetworkId, - } - err := p.Handshake(td, hash, headNum, genesis, nil) - if err != nil { - t.Fatalf("Handshake error: %s", err) - } - if p.announceType != announceTypeSigned { - t.Fatal("Incorrect announceType") +func newTestServerPeerSub() *testServerPeerSub { + return &testServerPeerSub{ + regCh: make(chan *peer, 1), + unregCh: make(chan *peer, 1), } } -func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testing.T) { - id := newNodeID(t).ID() - p := peer{ - Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocolVersion, - rw: &rwStub{ - WriteHook: func(recvList keyValueList) { - // checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned - recv, _ := recvList.decode() - var reqType uint64 - err := recv.get("announceType", &reqType) - if err != nil { - t.Fatal(err) - } - if reqType == announceTypeSigned { - t.Fatal("Expected not announceTypeSigned") - } - }, - ReadHook: func(l keyValueList) keyValueList { - l = l.add("serveHeaders", nil) - l = l.add("serveChainSince", uint64(0)) - l = l.add("serveStateSince", uint64(0)) - l = l.add("txRelay", nil) - l = l.add("flowControl/BL", uint64(0)) - l = l.add("flowControl/MRR", uint64(0)) - l = l.add("flowControl/MRC", testCostList(0)) - return l - }, - }, - network: NetworkId, - } - err := p.Handshake(td, hash, headNum, genesis, nil) - if err != nil { - t.Fatal(err) - } - if p.announceType == announceTypeSigned { - t.Fatal("Incorrect announceType") - } +func (t *testServerPeerSub) registerPeer(p *peer) { t.regCh <- p } +func (t *testServerPeerSub) unregisterPeer(p *peer) { t.unregCh <- p } + +func TestPeerSubscription(t *testing.T) { + peers := newPeerSet() + defer peers.Close() + + checkIds := func(expect []string) { + given := peers.AllPeerIDs() + if len(given) == 0 && len(expect) == 0 { + return + } + sort.Strings(given) + sort.Strings(expect) + if !reflect.DeepEqual(given, expect) { + t.Fatalf("all peer ids mismatch, want %v, given %v", expect, given) + } + } + checkPeers := func(peerCh chan *peer) { + select { + case <-peerCh: + case <-time.NewTimer(100 * time.Millisecond).C: + t.Fatalf("timeout, no event received") + } + select { + case <-peerCh: + t.Fatalf("unexpected event received") + case <-time.NewTimer(10 * time.Millisecond).C: + } + } + checkIds([]string{}) + + sub := newTestServerPeerSub() + peers.notify(sub) + + // Generate a random id and create the peer + var id enode.ID + rand.Read(id[:]) + peer := newPeer(2, NetworkId, false, p2p.NewPeer(id, "name", nil), nil) + peers.Register(peer) + + checkIds([]string{peer.id}) + checkPeers(sub.regCh) + + peers.Unregister(peer.id) + checkIds([]string{}) + checkPeers(sub.unregCh) } -func TestPeerHandshakeDefaultAllRequests(t *testing.T) { - id := newNodeID(t).ID() +func TestHandshakeLes2(t *testing.T) { testHandshake(t, lpv2) } +func TestHandshakeLes3(t *testing.T) { testHandshake(t, lpv3) } +func TestHandshakeLes4(t *testing.T) { testHandshake(t, lpv4) } - s := generateLesServer() +type fakeChain struct{} - p := peer{ - Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocolVersion, - rw: &rwStub{ - ReadHook: func(l keyValueList) keyValueList { - l = l.add("announceType", uint64(announceTypeSigned)) - l = l.add("allowedRequests", uint64(0)) - return l - }, - }, - network: NetworkId, - } - - err := p.Handshake(td, hash, headNum, genesis, s) - if err != nil { - t.Fatal(err) - } - - if p.onlyAnnounce { - t.Fatal("Incorrect announceType") - } -} - -func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) { - id := newNodeID(t).ID() - - s := generateLesServer() - s.config.UltraLightOnlyAnnounce = true - - p := peer{ - Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocolVersion, - rw: &rwStub{ - ReadHook: func(l keyValueList) keyValueList { - l = l.add("announceType", uint64(announceTypeSigned)) - return l - }, - WriteHook: func(l keyValueList) { - for _, v := range l { - if v.Key == "serveHeaders" || - v.Key == "serveChainSince" || - v.Key == "serveStateSince" || - v.Key == "txRelay" { - t.Fatalf("%v exists", v.Key) - } - } - }, - }, - network: NetworkId, - } - - err := p.Handshake(td, hash, headNum, genesis, s) - if err != nil { - t.Fatal(err) - } +func (f *fakeChain) Config() *params.ChainConfig { return params.MainnetChainConfig } +func (f *fakeChain) Genesis() *types.Block { + return core.DefaultGenesisBlock().ToBlock(rawdb.NewMemoryDatabase()) } -func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) { - id := newNodeID(t).ID() - - p := peer{ - Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocolVersion, - rw: &rwStub{ - ReadHook: func(l keyValueList) keyValueList { - l = l.add("flowControl/BL", uint64(0)) - l = l.add("flowControl/MRR", uint64(0)) - l = l.add("flowControl/MRC", RequestCostList{}) - - l = l.add("announceType", uint64(announceTypeSigned)) - - return l - }, - }, - network: NetworkId, - trusted: true, - } - - err := p.Handshake(td, hash, headNum, genesis, nil) - if err != nil { - t.Fatal(err) - } - - if !p.onlyAnnounce { - t.Fatal("onlyAnnounce must be true") - } -} - -func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) { - id := newNodeID(t).ID() - - p := peer{ - Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), - version: protocolVersion, - rw: &rwStub{ - ReadHook: func(l keyValueList) keyValueList { - l = l.add("flowControl/BL", uint64(0)) - l = l.add("flowControl/MRR", uint64(0)) - l = l.add("flowControl/MRC", RequestCostList{}) - l = l.add("announceType", uint64(announceTypeSigned)) - return l - }, - }, - network: NetworkId, - } - - err := p.Handshake(td, hash, headNum, genesis, nil) - if err == nil { - t.FailNow() - } -} - -func generateLesServer() *LesServer { - s := &LesServer{ - lesCommons: lesCommons{ - config: ð.Config{UltraLightOnlyAnnounce: true}, - }, - defParams: flowcontrol.ServerParams{ - BufLimit: uint64(300000000), - MinRecharge: uint64(50000), - }, - fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}), - } - s.costTracker, _ = newCostTracker(rawdb.NewMemoryDatabase(), s.config) - return s -} - -type rwStub struct { - ReadHook func(l keyValueList) keyValueList - WriteHook func(l keyValueList) -} - -func (s *rwStub) ReadMsg() (p2p.Msg, error) { - payload := keyValueList{} - payload = payload.add("protocolVersion", uint64(protocolVersion)) - payload = payload.add("networkId", uint64(NetworkId)) - payload = payload.add("headTd", td) - payload = payload.add("headHash", hash) - payload = payload.add("headNum", headNum) - payload = payload.add("genesisHash", genesis) - - if s.ReadHook != nil { - payload = s.ReadHook(payload) - } - size, p, err := rlp.EncodeToReader(payload) - if err != nil { - return p2p.Msg{}, err - } - return p2p.Msg{ - Size: uint32(size), - Payload: p, - }, nil -} - -func (s *rwStub) WriteMsg(m p2p.Msg) error { - recvList := keyValueList{} - if err := m.Decode(&recvList); err != nil { - return err - } - if s.WriteHook != nil { - s.WriteHook(recvList) +func (f *fakeChain) CurrentHeader() *types.Header { return &types.Header{Number: big.NewInt(10000000)} } + +func testHandshake(t *testing.T, protocol int) { + // Create a message pipe to communicate through + app, net := p2p.MsgPipe() + + // Generate a random id and create the peer + var id enode.ID + rand.Read(id[:]) + + peer1 := newPeer(protocol, NetworkId, false, p2p.NewPeer(id, "peer1", nil), net) + peer2 := newPeer(protocol, NetworkId, false, p2p.NewPeer(id, "peer2", nil), app) + + var ( + errCh1 = make(chan error, 1) + errCh2 = make(chan error, 1) + + td = big.NewInt(100) + head = common.HexToHash("deadbeef") + headNum = uint64(10) + genesis = common.HexToHash("cafebabe") + ) + + chain1, chain2 := &fakeChain{}, &fakeChain{} + forkID1, forkID2 := forkid.NewID(chain1), forkid.NewID(chain2) + filter1, filter2 := forkid.NewFilter(chain1), forkid.NewFilter(chain2) + + go func() { + // Exchange handshake with remote server peer + errCh1 <- peer1.handshake(td, head, headNum, genesis, forkID1, filter1, func(list *keyValueList) { + var announceType uint64 = announceTypeSigned + *list = (*list).add("announceType", announceType) + }, nil) + }() + go func() { + // Exchange handshake with remote client peer + errCh2 <- peer2.handshake(td, head, headNum, genesis, forkID2, filter2, nil, func(recv keyValueMap) error { + var reqType uint64 + err := recv.get("announceType", &reqType) + if err != nil { + t.Fatal(err) + } + if reqType != announceTypeSigned { + t.Fatal("Expected announceTypeSigned") + } + return nil + }) + }() + + for i := 0; i < 2; i++ { + select { + case err := <-errCh1: + if err != nil { + t.Fatalf("handshake failed, %v", err) + } + case err := <-errCh2: + if err != nil { + t.Fatalf("handshake failed, %v", err) + } + case <-time.NewTimer(500 * time.Millisecond).C: + t.Fatalf("timeout") + } } - return nil } diff --git a/les/protocol.go b/les/protocol.go index 36af88aea6d0..59cafcd43bbd 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -33,17 +33,18 @@ import ( const ( lpv2 = 2 lpv3 = 3 + lpv4 = 4 ) // Supported versions of the les protocol (first is primary) var ( - ClientProtocolVersions = []uint{lpv2, lpv3} - ServerProtocolVersions = []uint{lpv2, lpv3} + ClientProtocolVersions = []uint{lpv2, lpv3, lpv4} + ServerProtocolVersions = []uint{lpv2, lpv3, lpv4} AdvertiseProtocolVersions = []uint{lpv2} // clients are searching for the first advertised protocol in the list ) // Number of implemented message corresponding to different protocol versions. -var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24} +var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24, lpv4: 24} const ( NetworkId = 1 @@ -110,6 +111,7 @@ const ( ErrInvalidResponse ErrTooManyTimeouts ErrMissingKey + ErrForkIDRejected ) func (e errCode) String() string { @@ -132,6 +134,7 @@ var errorToString = map[int]string{ ErrInvalidResponse: "Invalid response", ErrTooManyTimeouts: "Too many request timeouts", ErrMissingKey: "Key missing from list", + ErrForkIDRejected: "Forkid rejected", } type announceBlock struct { diff --git a/les/server_handler.go b/les/server_handler.go index 16249ef1ba27..620dbaf60894 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" @@ -62,6 +63,7 @@ var ( // serverHandler is responsible for serving light client and process // all incoming light requests. type serverHandler struct { + forkFilter forkid.Filter blockchain *core.BlockChain chainDb ethdb.Database txpool *core.TxPool @@ -77,6 +79,7 @@ type serverHandler struct { func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler { handler := &serverHandler{ + forkFilter: forkid.NewFilter(blockchain), server: server, blockchain: blockchain, chainDb: chainDb, @@ -121,7 +124,7 @@ func (h *serverHandler) handle(p *peer) error { number = head.Number.Uint64() td = h.blockchain.GetTd(hash, number) ) - if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil { + if err := p.handshakeWithClient(td, hash, number, h.blockchain.Genesis().Hash(), forkid.NewID(h.blockchain), h.forkFilter, h.server); err != nil { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err }