From 88174caef22126fc6fa200cdf8dbf1183ea0b4c9 Mon Sep 17 00:00:00 2001 From: noot <36753753+noot@users.noreply.github.com> Date: Thu, 18 Nov 2021 11:02:39 -0500 Subject: [PATCH] fix(dot/network): fix bugs in notifications protocol handlers; add metrics for inbound/outbound streams (#2010) --- dot/network/block_announce_test.go | 2 +- dot/network/connmgr.go | 39 ++--- dot/network/discovery.go | 13 ++ dot/network/errors.go | 17 +++ dot/network/host_test.go | 2 +- dot/network/inbound.go | 76 +++++++++ dot/network/light.go | 75 +++++++++ dot/network/notifications.go | 237 +++++++++++++++++------------ dot/network/notifications_test.go | 31 ++-- dot/network/service.go | 225 +++++++++------------------ dot/network/service_test.go | 2 +- dot/network/utils.go | 4 + dot/services_test.go | 5 +- dot/sync/chain_sync.go | 3 +- dot/sync/chain_sync_test.go | 2 +- dot/types/grandpa.go | 2 +- lib/grandpa/grandpa.go | 9 +- 17 files changed, 443 insertions(+), 301 deletions(-) create mode 100644 dot/network/errors.go create mode 100644 dot/network/inbound.go diff --git a/dot/network/block_announce_test.go b/dot/network/block_announce_test.go index 0af14ab1e2..66fb6b787b 100644 --- a/dot/network/block_announce_test.go +++ b/dot/network/block_announce_test.go @@ -155,7 +155,7 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) { inboundHandshakeData: new(sync.Map), } testPeerID := peer.ID("noot") - nodeA.notificationsProtocols[BlockAnnounceMsgType].inboundHandshakeData.Store(testPeerID, handshakeData{}) + nodeA.notificationsProtocols[BlockAnnounceMsgType].inboundHandshakeData.Store(testPeerID, &handshakeData{}) err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{ BestBlockNumber: 100, diff --git a/dot/network/connmgr.go b/dot/network/connmgr.go index 37c5894f2f..9f7e257d73 100644 --- a/dot/network/connmgr.go +++ b/dot/network/connmgr.go @@ -12,7 +12,6 @@ import ( "github.com/libp2p/go-libp2p-core/connmgr" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" ma "github.com/multiformats/go-multiaddr" "github.com/ChainSafe/gossamer/dot/peerset" @@ -23,11 +22,9 @@ type ConnManager struct { sync.Mutex host *host min, max int + connectHandler func(peer.ID) disconnectHandler func(peer.ID) - // closeHandlerMap contains close handler corresponding to a protocol. - closeHandlerMap map[protocol.ID]func(peerID peer.ID) - // protectedPeers contains a list of peers that are protected from pruning // when we reach the maximum numbers of peers. protectedPeers *sync.Map // map[peer.ID]struct{} @@ -47,7 +44,6 @@ func newConnManager(min, max int, peerSetCfg *peerset.ConfigSet) (*ConnManager, return &ConnManager{ min: min, max: max, - closeHandlerMap: make(map[protocol.ID]func(peerID peer.ID)), protectedPeers: new(sync.Map), persistentPeers: new(sync.Map), peerSetHandler: psh, @@ -68,19 +64,19 @@ func (cm *ConnManager) Notifee() network.Notifiee { return nb } -// TagPeer peer +// TagPeer is unimplemented func (*ConnManager) TagPeer(peer.ID, string, int) {} -// UntagPeer peer +// UntagPeer is unimplemented func (*ConnManager) UntagPeer(peer.ID, string) {} -// UpsertTag peer +// UpsertTag is unimplemented func (*ConnManager) UpsertTag(peer.ID, string, func(int) int) {} -// GetTagInfo peer +// GetTagInfo is unimplemented func (*ConnManager) GetTagInfo(peer.ID) *connmgr.TagInfo { return &connmgr.TagInfo{} } -// TrimOpenConns peer +// TrimOpenConns is unimplemented func (*ConnManager) TrimOpenConns(context.Context) {} // Protect peer will add the given peer to the protectedPeerMap which will @@ -97,7 +93,7 @@ func (cm *ConnManager) Unprotect(id peer.ID, _ string) bool { return wasDeleted } -// Close peer +// Close is unimplemented func (*ConnManager) Close() error { return nil } // IsProtected returns whether the given peer is protected from pruning or not. @@ -134,6 +130,7 @@ func (cm *ConnManager) unprotectedPeers(peers []peer.ID) []peer.ID { func (cm *ConnManager) Connected(n network.Network, c network.Conn) { logger.Tracef( "Host %s connected to peer %s", n.LocalPeer(), c.RemotePeer()) + cm.connectHandler(c.RemotePeer()) cm.Lock() defer cm.Unlock() @@ -143,7 +140,9 @@ func (cm *ConnManager) Connected(n network.Network, c network.Conn) { return } + // TODO: peer scoring doesn't seem to prevent us from going over the max. // if over the max peer count, disconnect from (total_peers - maximum) peers + // (#2039) for i := 0; i < over; i++ { unprotPeers := cm.unprotectedPeers(n.Peers()) if len(unprotPeers) == 0 { @@ -170,31 +169,19 @@ func (cm *ConnManager) Disconnected(_ network.Network, c network.Conn) { logger.Tracef("Host %s disconnected from peer %s", c.LocalPeer(), c.RemotePeer()) cm.Unprotect(c.RemotePeer(), "") - if cm.disconnectHandler != nil { - cm.disconnectHandler(c.RemotePeer()) - } + cm.disconnectHandler(c.RemotePeer()) } -// OpenedStream is called when a stream opened +// OpenedStream is called when a stream is opened func (cm *ConnManager) OpenedStream(_ network.Network, s network.Stream) { logger.Tracef("Stream opened with peer %s using protocol %s", s.Conn().RemotePeer(), s.Protocol()) } -func (cm *ConnManager) registerCloseHandler(protocolID protocol.ID, cb func(id peer.ID)) { - cm.closeHandlerMap[protocolID] = cb -} - -// ClosedStream is called when a stream closed +// ClosedStream is called when a stream is closed func (cm *ConnManager) ClosedStream(_ network.Network, s network.Stream) { logger.Tracef("Stream closed with peer %s using protocol %s", s.Conn().RemotePeer(), s.Protocol()) - - cm.Lock() - defer cm.Unlock() - if closeCB, ok := cm.closeHandlerMap[s.Protocol()]; ok { - closeCB(s.Conn().RemotePeer()) - } } func (cm *ConnManager) isPersistent(p peer.ID) bool { diff --git a/dot/network/discovery.go b/dot/network/discovery.go index 911f88f5b4..215fc9734f 100644 --- a/dot/network/discovery.go +++ b/dot/network/discovery.go @@ -197,8 +197,21 @@ func (d *discovery) findPeers(ctx context.Context) { logger.Tracef("found new peer %s via DHT", peer.ID) + // TODO: this isn't working on the devnet (#2026) + // can remove the code block below which directly connects + // once that's fixed d.h.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL) d.handler.AddPeer(0, peer.ID) + + // found a peer, try to connect if we need more peers + if len(d.h.Network().Peers()) >= d.maxPeers { + d.h.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL) + return + } + + if err = d.h.Connect(d.ctx, peer); err != nil { + logger.Tracef("failed to connect to discovered peer %s: %s", peer.ID, err) + } } } } diff --git a/dot/network/errors.go b/dot/network/errors.go new file mode 100644 index 0000000000..a9c2bd94f3 --- /dev/null +++ b/dot/network/errors.go @@ -0,0 +1,17 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package network + +import ( + "errors" +) + +var ( + errCannotValidateHandshake = errors.New("failed to validate handshake") + errMessageTypeNotValid = errors.New("message type is not valid") + errMessageIsNotHandshake = errors.New("failed to convert message to Handshake") + errMissingHandshakeMutex = errors.New("outboundHandshakeMutex does not exist") + errInvalidHandshakeForPeer = errors.New("peer previously sent invalid handshake") + errHandshakeTimeout = errors.New("handshake timeout reached") +) diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 346ed6c42e..a2f45cac9f 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -313,7 +313,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { info := nodeA.notificationsProtocols[BlockAnnounceMsgType] // Set handshake data to received - info.inboundHandshakeData.Store(nodeB.host.id(), handshakeData{ + info.inboundHandshakeData.Store(nodeB.host.id(), &handshakeData{ received: true, validated: true, }) diff --git a/dot/network/inbound.go b/dot/network/inbound.go new file mode 100644 index 0000000000..154684ed28 --- /dev/null +++ b/dot/network/inbound.go @@ -0,0 +1,76 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package network + +import ( + libp2pnetwork "github.com/libp2p/go-libp2p-core/network" +) + +func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder, handler messageHandler) { + // we NEED to reset the stream if we ever return from this function, as if we return, + // the stream will never again be read by us, so we need to tell the remote side we're + // done with this stream, and they should also forget about it. + defer s.resetInboundStream(stream) + s.streamManager.logNewStream(stream) + + peer := stream.Conn().RemotePeer() + msgBytes := s.bufPool.get() + defer s.bufPool.put(msgBytes) + + for { + n, err := readStream(stream, msgBytes[:]) + if err != nil { + logger.Tracef( + "failed to read from stream id %s of peer %s using protocol %s: %s", + stream.ID(), stream.Conn().RemotePeer(), stream.Protocol(), err) + return + } + + s.streamManager.logMessageReceived(stream.ID()) + + // decode message based on message type + msg, err := decoder(msgBytes[:n], peer, isInbound(stream)) // stream should always be inbound if it passes through service.readStream + if err != nil { + logger.Tracef("failed to decode message from stream id %s using protocol %s: %s", + stream.ID(), stream.Protocol(), err) + continue + } + + logger.Tracef( + "host %s received message from peer %s: %s", + s.host.id(), peer, msg) + + if err = handler(stream, msg); err != nil { + logger.Tracef("failed to handle message %s from stream id %s: %s", msg, stream.ID(), err) + return + } + + s.host.bwc.LogRecvMessage(int64(n)) + } +} + +func (s *Service) resetInboundStream(stream libp2pnetwork.Stream) { + protocolID := stream.Protocol() + peerID := stream.Conn().RemotePeer() + + s.notificationsMu.Lock() + defer s.notificationsMu.Unlock() + + for _, prtl := range s.notificationsProtocols { + if prtl.protocolID != protocolID { + continue + } + + prtl.inboundHandshakeData.Delete(peerID) + break + } + + logger.Debugf( + "cleaning up inbound handshake data for protocol=%s, peer=%s", + stream.Protocol(), + peerID, + ) + + _ = stream.Reset() +} diff --git a/dot/network/light.go b/dot/network/light.go index de32e03945..d782568133 100644 --- a/dot/network/light.go +++ b/dot/network/light.go @@ -9,8 +9,71 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" + + libp2pnetwork "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" ) +// handleLightStream handles streams with the /light/2 protocol ID +func (s *Service) handleLightStream(stream libp2pnetwork.Stream) { + s.readStream(stream, s.decodeLightMessage, s.handleLightMsg) +} + +func (s *Service) decodeLightMessage(in []byte, peer peer.ID, _ bool) (Message, error) { + s.lightRequestMu.RLock() + defer s.lightRequestMu.RUnlock() + + // check if we are the requester + if _, ok := s.lightRequest[peer]; ok { + // if we are, decode the bytes as a LightResponse + return newLightResponseFromBytes(in) + } + + // otherwise, decode bytes as LightRequest + return newLightRequestFromBytes(in) +} + +func (s *Service) handleLightMsg(stream libp2pnetwork.Stream, msg Message) (err error) { + defer func() { + _ = stream.Close() + }() + + lr, ok := msg.(*LightRequest) + if !ok { + return nil + } + + resp := NewLightResponse() + switch { + case lr.RemoteCallRequest != nil: + resp.RemoteCallResponse, err = remoteCallResp(lr.RemoteCallRequest) + case lr.RemoteHeaderRequest != nil: + resp.RemoteHeaderResponse, err = remoteHeaderResp(lr.RemoteHeaderRequest) + case lr.RemoteChangesRequest != nil: + resp.RemoteChangesResponse, err = remoteChangeResp(lr.RemoteChangesRequest) + case lr.RemoteReadRequest != nil: + resp.RemoteReadResponse, err = remoteReadResp(lr.RemoteReadRequest) + case lr.RemoteReadChildRequest != nil: + resp.RemoteReadResponse, err = remoteReadChildResp(lr.RemoteReadChildRequest) + default: + logger.Warn("ignoring LightRequest without request data") + return nil + } + + if err != nil { + return err + } + + // TODO(arijit): Remove once we implement the internal APIs. Added to increase code coverage. (#1856) + logger.Debugf("LightResponse message: %s", resp) + + err = s.host.writeToStream(stream, resp) + if err != nil { + logger.Warnf("failed to send LightResponse message to peer %s: %s", stream.Conn().RemotePeer(), err) + } + return err +} + // Pair is a pair of arbitrary bytes. type Pair struct { First []byte @@ -46,6 +109,12 @@ func NewLightRequest() *LightRequest { } } +func newLightRequestFromBytes(in []byte) (msg *LightRequest, err error) { + msg = NewLightRequest() + err = msg.Decode(in) + return msg, err +} + func newRequest() *request { return &request{ RemoteCallRequest: *newRemoteCallRequest(), @@ -122,6 +191,12 @@ func NewLightResponse() *LightResponse { } } +func newLightResponseFromBytes(in []byte) (msg *LightResponse, err error) { + msg = NewLightResponse() + err = msg.Decode(in) + return msg, err +} + func newResponse() *response { return &response{ RemoteCallResponse: *newRemoteCallResponse(), diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 3bf9615ccc..804388d4b0 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -5,21 +5,19 @@ package network import ( "errors" - "reflect" + "fmt" + "io" "sync" "time" "github.com/ChainSafe/gossamer/dot/peerset" + + "github.com/libp2p/go-libp2p-core/mux" libp2pnetwork "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" ) -var ( - errCannotValidateHandshake = errors.New("failed to validate handshake") - maxHandshakeSize = reflect.TypeOf(BlockAnnounceHandshake{}).Size() -) - const handshakeTimeout = time.Second * 10 // Handshake is the interface all handshakes for notifications protocols must implement @@ -62,16 +60,28 @@ type handshakeReader struct { } type notificationsProtocol struct { - protocolID protocol.ID - getHandshake HandshakeGetter - handshakeDecoder HandshakeDecoder - handshakeValidator HandshakeValidator + protocolID protocol.ID + getHandshake HandshakeGetter + handshakeDecoder HandshakeDecoder + handshakeValidator HandshakeValidator + outboundHandshakeMutexes *sync.Map //map[peer.ID]*sync.Mutex + inboundHandshakeData *sync.Map //map[peer.ID]*handshakeData + outboundHandshakeData *sync.Map //map[peer.ID]*handshakeData +} - inboundHandshakeData *sync.Map //map[peer.ID]*handshakeData - outboundHandshakeData *sync.Map //map[peer.ID]*handshakeData +func newNotificationsProtocol(protocolID protocol.ID, handshakeGetter HandshakeGetter, handshakeDecoder HandshakeDecoder, handshakeValidator HandshakeValidator) *notificationsProtocol { + return ¬ificationsProtocol{ + protocolID: protocolID, + getHandshake: handshakeGetter, + handshakeValidator: handshakeValidator, + handshakeDecoder: handshakeDecoder, + outboundHandshakeMutexes: new(sync.Map), + inboundHandshakeData: new(sync.Map), + outboundHandshakeData: new(sync.Map), + } } -func (n *notificationsProtocol) getInboundHandshakeData(pid peer.ID) (handshakeData, bool) { +func (n *notificationsProtocol) getInboundHandshakeData(pid peer.ID) (*handshakeData, bool) { var ( data interface{} has bool @@ -79,13 +89,13 @@ func (n *notificationsProtocol) getInboundHandshakeData(pid peer.ID) (handshakeD data, has = n.inboundHandshakeData.Load(pid) if !has { - return handshakeData{}, false + return nil, false } - return data.(handshakeData), true + return data.(*handshakeData), true } -func (n *notificationsProtocol) getOutboundHandshakeData(pid peer.ID) (handshakeData, bool) { +func (n *notificationsProtocol) getOutboundHandshakeData(pid peer.ID) (*handshakeData, bool) { var ( data interface{} has bool @@ -93,10 +103,10 @@ func (n *notificationsProtocol) getOutboundHandshakeData(pid peer.ID) (handshake data, has = n.outboundHandshakeData.Load(pid) if !has { - return handshakeData{}, false + return nil, false } - return data.(handshakeData), true + return data.(*handshakeData), true } type handshakeData struct { @@ -104,15 +114,13 @@ type handshakeData struct { validated bool handshake Handshake stream libp2pnetwork.Stream - *sync.Mutex } -func newHandshakeData(received, validated bool, stream libp2pnetwork.Stream) handshakeData { - return handshakeData{ +func newHandshakeData(received, validated bool, stream libp2pnetwork.Stream) *handshakeData { + return &handshakeData{ received: received, validated: validated, stream: stream, - Mutex: new(sync.Mutex), } } @@ -121,7 +129,7 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode // if we don't have handshake data on this peer, or we haven't received the handshake from them already, // assume we are receiving the handshake var ( - hsData handshakeData + hsData *handshakeData has bool ) @@ -140,6 +148,7 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode } } +// createNotificationsMessageHandler returns a function that is called by the handler of *inbound* streams. func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, messageHandler NotificationsMessageHandler, batchHandler NotificationsMessageBatchHandler) messageHandler { return func(stream libp2pnetwork.Stream, m Message) error { if m == nil || info == nil || info.handshakeValidator == nil || messageHandler == nil { @@ -153,7 +162,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, ) if msg, ok = m.(NotificationsMessage); !ok { - return errors.New("message is not NotificationsMessage") + return fmt.Errorf("%w: expected %T but got %T", errMessageTypeNotValid, (NotificationsMessage)(nil), msg) } if msg.IsHandshake() { @@ -162,7 +171,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, hs, ok := msg.(Handshake) if !ok { - return errors.New("failed to convert message to Handshake") + return errMessageIsNotHandshake } // if we are the receiver and haven't received the handshake already, validate it @@ -198,6 +207,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, logger.Tracef("failed to send handshake to peer %s using protocol %s: %s", peer, info.protocolID, err) return err } + logger.Tracef("receiver: sent handshake to peer %s using protocol %s", peer, info.protocolID) } @@ -224,6 +234,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, if err != nil { return err } + msgs = append(msgs, &BatchMessage{ msg: msg, peer: peer, @@ -251,7 +262,23 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, } } +func closeOutboundStream(info *notificationsProtocol, peerID peer.ID, stream libp2pnetwork.Stream) { + logger.Debugf( + "cleaning up outbound handshake data for protocol=%s, peer=%s", + stream.Protocol(), + peerID, + ) + + info.outboundHandshakeData.Delete(peerID) + _ = stream.Close() +} + func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtocol, msg NotificationsMessage) { + if info.handshakeValidator == nil { + logger.Errorf("handshakeValidator is not set for protocol %s", info.protocolID) + return + } + if support, err := s.host.supportsProtocol(peer, info.protocolID); err != nil || !support { s.host.cm.peerSetHandler.ReportPeer(peerset.ReputationChange{ Value: peerset.BadProtocolValue, @@ -261,75 +288,12 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc return } - hsData, has := info.getOutboundHandshakeData(peer) - if has && !hsData.validated { - // peer has sent us an invalid handshake in the past, ignore + stream, err := s.sendHandshake(peer, hs, info) + if err != nil { + logger.Debugf("failed to send handshake to peer %s on protocol %s: %s", peer, info.protocolID, err) return } - if !has || !hsData.received || hsData.stream == nil { - if !has { - hsData = newHandshakeData(false, false, nil) - } - - hsData.Lock() - defer hsData.Unlock() - - logger.Tracef("sending outbound handshake to peer %s using protocol %s, message: %s", - peer, info.protocolID, hs) - stream, err := s.host.send(peer, info.protocolID, hs) - if err != nil { - logger.Tracef("failed to send message to peer %s: %s", peer, err) - return - } - - hsData.stream = stream - info.outboundHandshakeData.Store(peer, hsData) - - if info.handshakeValidator == nil { - return - } - - hsTimer := time.NewTimer(handshakeTimeout) - - var hs Handshake - select { - case <-hsTimer.C: - s.host.cm.peerSetHandler.ReportPeer(peerset.ReputationChange{ - Value: peerset.TimeOutValue, - Reason: peerset.TimeOutReason, - }, peer) - - logger.Tracef("handshake timeout reached for peer %s using protocol %s", peer, info.protocolID) - _ = stream.Close() - info.outboundHandshakeData.Delete(peer) - return - case hsResponse := <-s.readHandshake(stream, info.handshakeDecoder): - hsTimer.Stop() - if hsResponse.err != nil { - logger.Tracef("failed to read handshake from peer %s using protocol %s: %s", peer, info.protocolID, err) - _ = stream.Close() - info.outboundHandshakeData.Delete(peer) - return - } - - hs = hsResponse.hs - hsData.received = true - } - - err = info.handshakeValidator(peer, hs) - if err != nil { - logger.Tracef("failed to validate handshake from peer %s using protocol %s: %s", peer, info.protocolID, err) - hsData.validated = false - info.outboundHandshakeData.Store(peer, hsData) - return - } - - hsData.validated = true - info.outboundHandshakeData.Store(peer, hsData) - logger.Tracef("sender: validated handshake from peer %s using protocol %s", peer, info.protocolID) - } - if s.host.messageCache != nil { added, err := s.host.messageCache.put(peer, msg) if err != nil { @@ -349,18 +313,103 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc // we've completed the handshake with the peer, send message directly logger.Tracef("sending message to peer %s using protocol %s: %s", peer, info.protocolID, msg) - - err := s.host.writeToStream(hsData.stream, msg) - if err != nil { + if err := s.host.writeToStream(stream, msg); err != nil { logger.Debugf("failed to send message to peer %s: %s", peer, err) + + // the stream was closed or reset, close it on our end and delete it from our peer's data + if errors.Is(err, io.EOF) || errors.Is(err, mux.ErrReset) { + closeOutboundStream(info, peer, stream) + } + return } + logger.Tracef("successfully sent message on protocol %s to peer %s: message=", info.protocolID, peer, msg) s.host.cm.peerSetHandler.ReportPeer(peerset.ReputationChange{ Value: peerset.GossipSuccessValue, Reason: peerset.GossipSuccessReason, }, peer) } +func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsProtocol) (libp2pnetwork.Stream, error) { + mu, has := info.outboundHandshakeMutexes.Load(peer) + if !has { + // this should not happen + return nil, errMissingHandshakeMutex + } + + // multiple processes could each call this upcoming section, opening multiple streams and + // sending multiple handshakes. thus, we need to have a per-peer and per-protocol lock + mu.(*sync.Mutex).Lock() + defer mu.(*sync.Mutex).Unlock() + + hsData, has := info.getOutboundHandshakeData(peer) + switch { + case has && !hsData.validated: + // peer has sent us an invalid handshake in the past, ignore + return nil, errInvalidHandshakeForPeer + case has && hsData.validated: + return hsData.stream, nil + case !has: + hsData = newHandshakeData(false, false, nil) + } + + logger.Tracef("sending outbound handshake to peer %s on protocol %s, message: %s", + peer, info.protocolID, hs) + stream, err := s.host.send(peer, info.protocolID, hs) + if err != nil { + logger.Tracef("failed to send message to peer %s: %s", peer, err) + // don't need to close the stream here, as it's nil! + return nil, err + } + + hsData.stream = stream + + hsTimer := time.NewTimer(handshakeTimeout) + + var resp Handshake + select { + case <-hsTimer.C: + s.host.cm.peerSetHandler.ReportPeer(peerset.ReputationChange{ + Value: peerset.TimeOutValue, + Reason: peerset.TimeOutReason, + }, peer) + + logger.Tracef("handshake timeout reached for peer %s using protocol %s", peer, info.protocolID) + closeOutboundStream(info, peer, stream) + return nil, errHandshakeTimeout + case hsResponse := <-s.readHandshake(stream, info.handshakeDecoder): + if !hsTimer.Stop() { + <-hsTimer.C + } + + if hsResponse.err != nil { + logger.Tracef("failed to read handshake from peer %s using protocol %s: %s", peer, info.protocolID, err) + closeOutboundStream(info, peer, stream) + return nil, hsResponse.err + } + + resp = hsResponse.hs + hsData.received = true + } + + if err = info.handshakeValidator(peer, resp); err != nil { + logger.Tracef("failed to validate handshake from peer %s using protocol %s: %s", peer, info.protocolID, err) + hsData.validated = false + hsData.stream = nil + _ = stream.Reset() + info.outboundHandshakeData.Store(peer, hsData) + // don't delete handshake data, as we want to store that the handshake for this peer was invalid + // and not to exchange messages over this protocol with it + return nil, err + } + + hsData.validated = true + hsData.handshake = resp + info.outboundHandshakeData.Store(peer, hsData) + logger.Tracef("sender: validated handshake from peer %s using protocol %s", peer, info.protocolID) + return hsData.stream, nil +} + // broadcastExcluding sends a message to each connected peer except the given peer, // and peers that have previously sent us the message or who we have already sent the message to. // used for notifications sub-protocols to gossip a message diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index 17bde80caf..0f2dc4b03a 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -4,7 +4,7 @@ package network import ( - "fmt" + "errors" "math/big" "reflect" "sync" @@ -21,10 +21,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestHandshake_SizeOf(t *testing.T) { - require.Equal(t, uint32(maxHandshakeSize), uint32(72)) -} - func TestCreateDecoder_BlockAnnounce(t *testing.T) { basePath := utils.NewTestBasePath(t, "nodeA") @@ -49,7 +45,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { // haven't received handshake from peer testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ") - info.inboundHandshakeData.Store(testPeerID, handshakeData{ + info.inboundHandshakeData.Store(testPeerID, &handshakeData{ received: false, }) @@ -134,7 +130,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) { handler := s.createNotificationsMessageHandler(info, s.handleBlockAnnounceMessage, nil) // set handshake data to received - info.inboundHandshakeData.Store(testPeerID, handshakeData{ + info.inboundHandshakeData.Store(testPeerID, &handshakeData{ received: true, validated: true, }) @@ -250,16 +246,14 @@ func Test_HandshakeTimeout(t *testing.T) { nodeB.noGossip = true // create info and handler - info := ¬ificationsProtocol{ - protocolID: nodeA.host.protocolID + blockAnnounceID, - getHandshake: nodeA.getBlockAnnounceHandshake, - handshakeValidator: nodeA.validateBlockAnnounceHandshake, - inboundHandshakeData: new(sync.Map), - outboundHandshakeData: new(sync.Map), + testHandshakeDecoder := func([]byte) (Handshake, error) { + return nil, errors.New("unimplemented") } + info := newNotificationsProtocol(nodeA.host.protocolID+blockAnnounceID, + nodeA.getBlockAnnounceHandshake, testHandshakeDecoder, nodeA.validateBlockAnnounceHandshake) nodeB.host.h.SetStreamHandler(info.protocolID, func(stream libp2pnetwork.Stream) { - fmt.Println("never respond a handshake message") + // should not respond to a handshake message }) addrInfosB := nodeB.host.addrInfo() @@ -280,13 +274,14 @@ func Test_HandshakeTimeout(t *testing.T) { } nodeA.GossipMessage(testHandshakeMsg) + info.outboundHandshakeMutexes.Store(nodeB.host.id(), new(sync.Mutex)) go nodeA.sendData(nodeB.host.id(), testHandshakeMsg, info, nil) time.Sleep(time.Second) - // Verify that handshake data exists. + // handshake data shouldn't exist, as nodeB hasn't responded yet _, ok := info.getOutboundHandshakeData(nodeB.host.id()) - require.True(t, ok) + require.False(t, ok) // a stream should be open until timeout connAToB := nodeA.host.h.Network().ConnsToPeer(nodeB.host.id()) @@ -296,7 +291,7 @@ func Test_HandshakeTimeout(t *testing.T) { // after the timeout time.Sleep(handshakeTimeout) - // handshake data should be removed + // handshake data shouldn't exist still _, ok = info.getOutboundHandshakeData(nodeB.host.id()) require.False(t, ok) @@ -361,7 +356,7 @@ func TestCreateNotificationsMessageHandler_HandleTransaction(t *testing.T) { handler := s.createNotificationsMessageHandler(info, s.handleTransactionMessage, txnBatchHandler) // set handshake data to received - info.inboundHandshakeData.Store(testPeerID, handshakeData{ + info.inboundHandshakeData.Store(testPeerID, &handshakeData{ received: true, validated: true, }) diff --git a/dot/network/service.go b/dot/network/service.go index 0f51a2f049..58c0167954 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -7,7 +7,6 @@ import ( "context" "errors" "fmt" - "io" "math/big" "strings" "sync" @@ -230,8 +229,26 @@ func (s *Service) Start() error { } // since this opens block announce streams, it should happen after the protocol is registered + // NOTE: this only handles *incoming* connections s.host.h.Network().SetConnHandler(s.handleConn) + // this handles all new connections (incoming and outgoing) + // it creates a per-protocol mutex for sending outbound handshakes to the peer + s.host.cm.connectHandler = func(peerID peer.ID) { + for _, prtl := range s.notificationsProtocols { + prtl.outboundHandshakeMutexes.Store(peerID, new(sync.Mutex)) + } + } + + // when a peer gets disconnected, we should clear all handshake data we have for it. + s.host.cm.disconnectHandler = func(peerID peer.ID) { + for _, prtl := range s.notificationsProtocols { + prtl.outboundHandshakeMutexes.Delete(peerID) + prtl.inboundHandshakeData.Delete(peerID) + prtl.outboundHandshakeData.Delete(peerID) + } + } + // log listening addresses to console for _, addr := range s.host.multiaddrs() { logger.Infof("Started listening on %s", addr) @@ -274,11 +291,24 @@ func (s *Service) collectNetworkMetrics() { totalConn := metrics.GetOrRegisterGauge("network/node/totalConnection", metrics.DefaultRegistry) networkLatency := metrics.GetOrRegisterGauge("network/node/latency", metrics.DefaultRegistry) syncedBlocks := metrics.GetOrRegisterGauge("service/blocks/sync", metrics.DefaultRegistry) + numInboundBlockAnnounceStreams := metrics.GetOrRegisterGauge("network/streams/block_announce/inbound", metrics.DefaultRegistry) + numOutboundBlockAnnounceStreams := metrics.GetOrRegisterGauge("network/streams/block_announce/outbound", metrics.DefaultRegistry) + numInboundGrandpaStreams := metrics.GetOrRegisterGauge("network/streams/grandpa/inbound", metrics.DefaultRegistry) + numOutboundGrandpaStreams := metrics.GetOrRegisterGauge("network/streams/grandpa/outbound", metrics.DefaultRegistry) + totalInboundStreams := metrics.GetOrRegisterGauge("network/streams/total/inbound", metrics.DefaultRegistry) + totalOutboundStreams := metrics.GetOrRegisterGauge("network/streams/total/outbound", metrics.DefaultRegistry) peerCount.Update(int64(s.host.peerCount())) totalConn.Update(int64(len(s.host.h.Network().Conns()))) networkLatency.Update(int64(s.host.h.Peerstore().LatencyEWMA(s.host.id()))) + numInboundBlockAnnounceStreams.Update(s.getNumStreams(BlockAnnounceMsgType, true)) + numOutboundBlockAnnounceStreams.Update(s.getNumStreams(BlockAnnounceMsgType, false)) + numInboundGrandpaStreams.Update(s.getNumStreams(ConsensusMsgType, true)) + numOutboundGrandpaStreams.Update(s.getNumStreams(ConsensusMsgType, false)) + totalInboundStreams.Update(s.getTotalStreams(true)) + totalOutboundStreams.Update(s.getTotalStreams(false)) + num, err := s.blockState.BestBlockNumber() if err != nil { syncedBlocks.Update(0) @@ -290,6 +320,46 @@ func (s *Service) collectNetworkMetrics() { } } +func (s *Service) getTotalStreams(inbound bool) (count int64) { + for _, conn := range s.host.h.Network().Conns() { + for _, stream := range conn.GetStreams() { + streamIsInbound := isInbound(stream) + if (streamIsInbound && inbound) || (!streamIsInbound && !inbound) { + count++ + } + } + } + return count +} + +func (s *Service) getNumStreams(protocolID byte, inbound bool) (count int64) { + np, has := s.notificationsProtocols[protocolID] + if !has { + return 0 + } + + var hsData *sync.Map + if inbound { + hsData = np.inboundHandshakeData + } else { + hsData = np.outboundHandshakeData + } + + hsData.Range(func(_, data interface{}) bool { + if data == nil { + return true + } + + if data.(*handshakeData).stream != nil { + count++ + } + + return true + }) + + return count +} + func (s *Service) logPeerCount() { ticker := time.NewTicker(time.Second * 30) defer ticker.Stop() @@ -418,46 +488,13 @@ func (s *Service) RegisterNotificationsProtocol( return errors.New("notifications protocol with message type already exists") } - np := ¬ificationsProtocol{ - protocolID: protocolID, - getHandshake: handshakeGetter, - handshakeValidator: handshakeValidator, - handshakeDecoder: handshakeDecoder, - inboundHandshakeData: new(sync.Map), - outboundHandshakeData: new(sync.Map), - } + np := newNotificationsProtocol(protocolID, handshakeGetter, handshakeDecoder, handshakeValidator) s.notificationsProtocols[messageID] = np - - connMgr := s.host.h.ConnManager().(*ConnManager) - connMgr.registerCloseHandler(protocolID, func(peerID peer.ID) { - if _, ok := np.getInboundHandshakeData(peerID); ok { - logger.Tracef( - "Cleaning up inbound handshake data for peer %s and protocol %s", - peerID, protocolID) - np.inboundHandshakeData.Delete(peerID) - } - - if _, ok := np.getOutboundHandshakeData(peerID); ok { - logger.Tracef( - "Cleaning up outbound handshake data for peer %s and protocol %s", - peerID, protocolID) - np.outboundHandshakeData.Delete(peerID) - } - }) - - info := s.notificationsProtocols[messageID] - - decoder := createDecoder(info, handshakeDecoder, messageDecoder) - handlerWithValidate := s.createNotificationsMessageHandler(info, messageHandler, batchHandler) + decoder := createDecoder(np, handshakeDecoder, messageDecoder) + handlerWithValidate := s.createNotificationsMessageHandler(np, messageHandler, batchHandler) s.host.registerStreamHandler(protocolID, func(stream libp2pnetwork.Stream) { logger.Tracef("received stream using sub-protocol %s", protocolID) - conn := stream.Conn() - if conn == nil { - logger.Error("Failed to get connection from stream") - return - } - s.readStream(stream, decoder, handlerWithValidate) }) @@ -517,120 +554,6 @@ func (s *Service) SendMessage(to peer.ID, msg NotificationsMessage) error { return errors.New("message not supported by any notifications protocol") } -// handleLightStream handles streams with the /light/2 protocol ID -func (s *Service) handleLightStream(stream libp2pnetwork.Stream) { - s.readStream(stream, s.decodeLightMessage, s.handleLightMsg) -} - -func (s *Service) decodeLightMessage(in []byte, peer peer.ID, _ bool) (Message, error) { - s.lightRequestMu.RLock() - defer s.lightRequestMu.RUnlock() - - // check if we are the requester - if _, requested := s.lightRequest[peer]; requested { - // if we are, decode the bytes as a LightResponse - msg := NewLightResponse() - err := msg.Decode(in) - return msg, err - } - - // otherwise, decode bytes as LightRequest - msg := NewLightRequest() - err := msg.Decode(in) - return msg, err -} - -func isInbound(stream libp2pnetwork.Stream) bool { - return stream.Stat().Direction == libp2pnetwork.DirInbound -} - -func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder, handler messageHandler) { - s.streamManager.logNewStream(stream) - - peer := stream.Conn().RemotePeer() - msgBytes := s.bufPool.get() - defer s.bufPool.put(msgBytes) - - for { - tot, err := readStream(stream, msgBytes[:]) - if errors.Is(err, io.EOF) { - return - } else if err != nil { - logger.Tracef( - "failed to read from stream id %s of peer %s using protocol %s: %s", - stream.ID(), stream.Conn().RemotePeer(), stream.Protocol(), err) - _ = stream.Close() - return - } - - s.streamManager.logMessageReceived(stream.ID()) - - // decode message based on message type - msg, err := decoder(msgBytes[:tot], peer, isInbound(stream)) - if err != nil { - logger.Tracef("failed to decode message from stream id %s using protocol %s: %s", - stream.ID(), stream.Protocol(), err) - continue - } - - logger.Tracef( - "host %s received message from peer %s: %s", - s.host.id(), peer, msg.String()) - - err = handler(stream, msg) - if err != nil { - logger.Tracef("failed to handle message %s from stream id %s: %s", msg, stream.ID(), err) - _ = stream.Close() - return - } - - s.host.bwc.LogRecvMessage(int64(tot)) - } -} - -func (s *Service) handleLightMsg(stream libp2pnetwork.Stream, msg Message) error { - defer func() { - _ = stream.Close() - }() - - lr, ok := msg.(*LightRequest) - if !ok { - return nil - } - - resp := NewLightResponse() - var err error - switch { - case lr.RemoteCallRequest != nil: - resp.RemoteCallResponse, err = remoteCallResp(lr.RemoteCallRequest) - case lr.RemoteHeaderRequest != nil: - resp.RemoteHeaderResponse, err = remoteHeaderResp(lr.RemoteHeaderRequest) - case lr.RemoteChangesRequest != nil: - resp.RemoteChangesResponse, err = remoteChangeResp(lr.RemoteChangesRequest) - case lr.RemoteReadRequest != nil: - resp.RemoteReadResponse, err = remoteReadResp(lr.RemoteReadRequest) - case lr.RemoteReadChildRequest != nil: - resp.RemoteReadResponse, err = remoteReadChildResp(lr.RemoteReadChildRequest) - default: - logger.Warn("ignoring LightRequest without request data") - return nil - } - - if err != nil { - logger.Errorf("failed to get the response: %s", err) - return err - } - - // TODO(arijit): Remove once we implement the internal APIs. Added to increase code coverage. (#1856) - logger.Debugf("LightResponse message: %s", resp) - - err = s.host.writeToStream(stream, resp) - if err != nil { - logger.Warnf("failed to send LightResponse message to peer %s: %s", stream.Conn().RemotePeer(), err) - } - return err -} - // Health returns information about host needed for the rpc server func (s *Service) Health() common.Health { return common.Health{ diff --git a/dot/network/service_test.go b/dot/network/service_test.go index f4a5b0f48b..9570efe7e8 100644 --- a/dot/network/service_test.go +++ b/dot/network/service_test.go @@ -211,7 +211,7 @@ func TestBroadcastDuplicateMessage(t *testing.T) { require.NotNil(t, stream) protocol := nodeA.notificationsProtocols[BlockAnnounceMsgType] - protocol.outboundHandshakeData.Store(nodeB.host.id(), handshakeData{ + protocol.outboundHandshakeData.Store(nodeB.host.id(), &handshakeData{ received: true, validated: true, stream: stream, diff --git a/dot/network/utils.go b/dot/network/utils.go index ceff586422..fb16633a15 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -20,6 +20,10 @@ import ( "github.com/multiformats/go-multiaddr" ) +func isInbound(stream libp2pnetwork.Stream) bool { + return stream.Stat().Direction == libp2pnetwork.DirInbound +} + // stringToAddrInfos converts a single string peer id to AddrInfo func stringToAddrInfo(s string) (peer.AddrInfo, error) { maddr, err := multiaddr.NewMultiaddr(s) diff --git a/dot/services_test.go b/dot/services_test.go index d958449cc4..a2e34bec3b 100644 --- a/dot/services_test.go +++ b/dot/services_test.go @@ -269,7 +269,10 @@ func TestCreateGrandpaService(t *testing.T) { dh, err := createDigestHandler(stateSrvc) require.NoError(t, err) - gs, err := createGRANDPAService(cfg, stateSrvc, dh, ks.Gran, &network.Service{}) + networkSrvc, err := createNetworkService(cfg, stateSrvc) + require.NoError(t, err) + + gs, err := createGRANDPAService(cfg, stateSrvc, dh, ks.Gran, networkSrvc) require.NoError(t, err) require.NotNil(t, gs) } diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 5b527817d3..7b0741f8e6 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -7,6 +7,7 @@ import ( "context" "crypto/rand" "errors" + "fmt" "math/big" "strings" "sync" @@ -853,7 +854,7 @@ func (cs *chainSync) validateBlockData(req *network.BlockRequestMessage, bd *typ } if (requestedData&network.RequestedDataBody>>1) == 1 && bd.Body == nil { - return errNilBodyInResponse + return fmt.Errorf("%w: hash=%s", errNilBodyInResponse, bd.Hash) } return nil diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go index 0c3131c6de..213733b244 100644 --- a/dot/sync/chain_sync_test.go +++ b/dot/sync/chain_sync_test.go @@ -472,7 +472,7 @@ func TestValidateBlockData(t *testing.T) { err = cs.validateBlockData(req, &types.BlockData{ Header: &types.Header{}, }, "") - require.Equal(t, errNilBodyInResponse, err) + require.ErrorIs(t, err, errNilBodyInResponse) err = cs.validateBlockData(req, &types.BlockData{ Header: &types.Header{}, diff --git a/dot/types/grandpa.go b/dot/types/grandpa.go index 00df39f626..6d098d834e 100644 --- a/dot/types/grandpa.go +++ b/dot/types/grandpa.go @@ -76,7 +76,7 @@ func (gv *GrandpaVoter) PublicKeyBytes() ed25519.PublicKeyBytes { // String returns a formatted GrandpaVoter string func (gv *GrandpaVoter) String() string { - return fmt.Sprintf("[key=0x%s id=%d]", gv.PublicKeyBytes(), gv.ID) + return fmt.Sprintf("[key=%s id=%d]", gv.PublicKeyBytes(), gv.ID) } // NewGrandpaVotersFromAuthorities returns an array of GrandpaVoters given an array of GrandpaAuthorities diff --git a/lib/grandpa/grandpa.go b/lib/grandpa/grandpa.go index 194d85edfb..4f623cb08a 100644 --- a/lib/grandpa/grandpa.go +++ b/lib/grandpa/grandpa.go @@ -160,6 +160,10 @@ func NewService(cfg *Config) (*Service, error) { interval: cfg.Interval, } + if err := s.registerProtocol(); err != nil { + return nil, err + } + s.messageHandler = NewMessageHandler(s, s.blockState) s.tracker = newTracker(s.blockState, s.messageHandler) s.paused.Store(false) @@ -168,11 +172,6 @@ func NewService(cfg *Config) (*Service, error) { // Start begins the GRANDPA finality service func (s *Service) Start() error { - // TODO: determine if we need to send a catch-up request (#1531) - if err := s.registerProtocol(); err != nil { - return err - } - // if we're not an authority, we don't need to worry about the voting process. // the grandpa service is only used to verify incoming block justifications if !s.authority {