From 803185e5342ef507b4033f2d4dcdad92a6eb4a3c Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Thu, 9 May 2024 18:43:26 -0500 Subject: [PATCH] server: Make manual peer disconnect synchronous. This refactors the logic for manually disconnecting peers out of the peer handler since it no longer needs to be plumbed through the query channel. This is a part of the overall effort to convert all of the code related to updating and querying the server's peer state to synchronous code that makes use of a separate mutex to protect it. --- rpcadaptors.go | 56 +++++++++++++++++++++++++++++++++++++++----------- server.go | 42 ------------------------------------- 2 files changed, 44 insertions(+), 54 deletions(-) diff --git a/rpcadaptors.go b/rpcadaptors.go index 1d86f07f0..e0fa96cd2 100644 --- a/rpcadaptors.go +++ b/rpcadaptors.go @@ -228,6 +228,46 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error { return nil } +// disconnectNode disconnects any peers that the provided compare function +// returns true for. It applies to both inbound and outbound peers. +// +// An error will be returned if no matching peers are found (aka the compare +// function returns false for all peers). +// +// This function is safe for concurrent access. +func (cm *rpcConnManager) disconnectNode(cmp func(sp *serverPeer) bool) error { + state := &cm.server.peerState + defer state.Unlock() + state.Lock() + + // Check inbound peers. No callback is passed since there are no additional + // actions on disconnect for inbound peers. + found := disconnectPeer(state.inboundPeers, cmp, nil) + if found { + return nil + } + + // Check outbound peers in a loop to ensure all outbound connections to the + // same ip:port are disconnected when there are multiple. + var numFound uint32 + for ; ; numFound++ { + found = disconnectPeer(state.outboundPeers, cmp, func(sp *serverPeer) { + // Update the group counts since the peer will be removed from the + // persistent peers just after this func returns. + remoteAddr := wireToAddrmgrNetAddress(sp.NA()) + state.outboundGroups[remoteAddr.GroupKey()]-- + }) + if !found { + break + } + } + + if numFound == 0 { + return errors.New("peer not found") + } + return nil +} + // DisconnectByID disconnects the peer associated with the provided id. This // applies to both inbound and outbound peers. Attempting to remove an id that // does not exist will return an error. @@ -235,12 +275,8 @@ func (cm *rpcConnManager) RemoveByAddr(addr string) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) DisconnectByID(id int32) error { - replyChan := make(chan error) - cm.server.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.ID() == id }, - reply: replyChan, - } - return <-replyChan + cmp := func(sp *serverPeer) bool { return sp.ID() == id } + return cm.disconnectNode(cmp) } // DisconnectByAddr disconnects the peer associated with the provided address. @@ -250,12 +286,8 @@ func (cm *rpcConnManager) DisconnectByID(id int32) error { // This function is safe for concurrent access and is part of the // rpcserver.ConnManager interface implementation. func (cm *rpcConnManager) DisconnectByAddr(addr string) error { - replyChan := make(chan error) - cm.server.query <- disconnectNodeMsg{ - cmp: func(sp *serverPeer) bool { return sp.Addr() == addr }, - reply: replyChan, - } - return <-replyChan + cmp := func(sp *serverPeer) bool { return sp.Addr() == addr } + return cm.disconnectNode(cmp) } // ConnectedCount returns the number of currently connected peers. diff --git a/server.go b/server.go index 8ed9ee326..bed64f53d 100644 --- a/server.go +++ b/server.go @@ -2088,51 +2088,9 @@ func (s *server) handleBroadcastMsg(state *peerState, bmsg *broadcastMsg) { }) } -type disconnectNodeMsg struct { - cmp func(*serverPeer) bool - reply chan error -} - // handleQuery is the central handler for all queries and commands from other // goroutines related to peer state. func (s *server) handleQuery(ctx context.Context, state *peerState, querymsg interface{}) { - switch msg := querymsg.(type) { - case disconnectNodeMsg: - // Check inbound peers. We pass a nil callback since we don't - // require any additional actions on disconnect for inbound peers. - state.Lock() - found := disconnectPeer(state.inboundPeers, msg.cmp, nil) - if found { - state.Unlock() - msg.reply <- nil - return - } - - // Check outbound peers. - found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) { - // Keep group counts ok since we remove from - // the list now. - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - }) - if found { - // If there are multiple outbound connections to the same - // ip:port, continue disconnecting them all until no such - // peers are found. - for found { - found = disconnectPeer(state.outboundPeers, msg.cmp, func(sp *serverPeer) { - remoteAddr := wireToAddrmgrNetAddress(sp.NA()) - state.outboundGroups[remoteAddr.GroupKey()]-- - }) - } - state.Unlock() - msg.reply <- nil - return - } - state.Unlock() - - msg.reply <- errors.New("peer not found") - } } // disconnectPeer attempts to drop the connection of a targeted peer in the