diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index edb55d371ce..05e7d807ab2 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -626,7 +626,7 @@ fn connection_manager_logger( println!("'{}' connected to '{}'", node_name, get_name(conn.peer_node_id()),); }, }, - PeerDisconnected(node_id) => { + PeerDisconnected(_, node_id) => { println!("'{}' disconnected from '{}'", get_name(node_id), node_name); }, PeerConnectFailed(node_id, err) => { diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index 22c368906ed..60b1cb2ce0b 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -280,7 +280,7 @@ async fn peer_to_peer_messaging() { #[runtime::test] async fn peer_to_peer_messaging_simultaneous() { - const NUM_MSGS: usize = 10; + const NUM_MSGS: usize = 100; let shutdown = Shutdown::new(); let (comms_node1, mut inbound_rx1, outbound_tx1, _) = spawn_node(Protocols::new(), shutdown.to_signal()).await; @@ -324,6 +324,11 @@ async fn peer_to_peer_messaging_simultaneous() { .await .unwrap(); + comms_node1 + .connectivity() + .dial_peer(comms_node2.node_identity().node_id().clone()) + .await + .unwrap(); // Simultaneously send messages between the two nodes let handle1 = task::spawn(async move { for i in 0..NUM_MSGS { diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index 1d0bb6b4964..bf027b24261 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -43,7 +43,7 @@ use super::{ }; use crate::{ backoff::Backoff, - connection_manager::{metrics, ConnectionDirection}, + connection_manager::{metrics, ConnectionDirection, ConnectionId}, multiplexing::Substream, noise::NoiseConfig, peer_manager::{NodeId, NodeIdentity, PeerManagerError}, @@ -61,7 +61,7 @@ const DIALER_REQUEST_CHANNEL_SIZE: usize = 32; pub enum ConnectionManagerEvent { // Peer 
connection PeerConnected(PeerConnection), - PeerDisconnected(NodeId), + PeerDisconnected(ConnectionId, NodeId), PeerConnectFailed(NodeId, ConnectionManagerError), PeerInboundConnectFailed(ConnectionManagerError), @@ -74,7 +74,7 @@ impl fmt::Display for ConnectionManagerEvent { use ConnectionManagerEvent::*; match self { PeerConnected(conn) => write!(f, "PeerConnected({})", conn), - PeerDisconnected(node_id) => write!(f, "PeerDisconnected({})", node_id.short_str()), + PeerDisconnected(id, node_id) => write!(f, "PeerDisconnected({}, {})", id, node_id.short_str()), PeerConnectFailed(node_id, err) => write!(f, "PeerConnectFailed({}, {:?})", node_id.short_str(), err), PeerInboundConnectFailed(err) => write!(f, "PeerInboundConnectFailed({:?})", err), NewInboundSubstream(node_id, protocol, _) => write!( diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index e6fd0d90e58..d2a8d415ed5 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -484,12 +484,29 @@ impl PeerConnectionActor { /// /// silent - true to suppress the PeerDisconnected event, false to publish the event async fn disconnect(&mut self, silent: bool) -> Result<(), PeerConnectionError> { - if !silent { - self.notify_event(ConnectionManagerEvent::PeerDisconnected(self.peer_node_id.clone())) - .await; + match self.control.close().await { + Err(yamux::ConnectionError::Closed) => { + debug!( + target: LOG_TARGET, + "(Peer = {}) Connection already closed", + self.peer_node_id.short_str() + ); + + return Ok(()); + }, + // Only emit closed event once + _ => { + if !silent { + self.notify_event(ConnectionManagerEvent::PeerDisconnected( + self.id, + self.peer_node_id.clone(), + )) + .await; + } + }, } - self.control.close().await?; + self.request_rx.close(); debug!( target: LOG_TARGET, diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index f37b113cd56..d99f6a1b163 
100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -490,73 +490,79 @@ impl ConnectivityManagerActor { event: &ConnectionManagerEvent, ) -> Result<(), ConnectivityError> { use ConnectionManagerEvent::*; - #[allow(clippy::single_match)] + debug!(target: LOG_TARGET, "Received event: {}", event); match event { PeerConnected(new_conn) => { - // self.connection_manager - // .cancel_dial(new_conn.peer_node_id().clone()) - // .await?; - - match self.pool.get_connection(new_conn.peer_node_id()) { + match self.pool.get_connection(new_conn.peer_node_id()).cloned() { Some(existing_conn) if !existing_conn.is_connected() => { debug!( target: LOG_TARGET, - "Tie break: Existing connection was not connected, resolving tie break by using the new \ - connection. (New={}, Existing={})", - new_conn, - existing_conn, - ); - }, - Some(existing_conn) if existing_conn.age() >= Duration::from_secs(60) => { - debug!( - target: LOG_TARGET, - "Tie break: Existing connection is reported as connected however the authenticated peer \ - is still attempting to connect to us. Resolving tie break by using the new connection. \ - (New={}, Existing={})", - new_conn, - existing_conn, + "Tie break: Existing connection (id: {}, peer: {}, direction: {}) was not connected, \ + resolving tie break by using the new connection. 
(New: id: {}, peer: {}, direction: {})", + existing_conn.id(), + existing_conn.peer_node_id(), + existing_conn.direction(), + new_conn.id(), + new_conn.peer_node_id(), + new_conn.direction(), ); - let node_id = existing_conn.peer_node_id().clone(); - let direction = existing_conn.direction(); - delayed_close(existing_conn.clone(), self.config.connection_tie_break_linger); - self.publish_event(ConnectivityEvent::PeerConnectionWillClose(node_id, direction)); + self.pool.remove(existing_conn.peer_node_id()); }, - Some(existing_conn) if self.tie_break_existing_connection(existing_conn, new_conn) => { + Some(mut existing_conn) if self.tie_break_existing_connection(&existing_conn, new_conn) => { debug!( target: LOG_TARGET, - "Tie break: (Peer = {}) Keep new {} connection, Disconnect existing {} connection", - new_conn.peer_node_id().short_str(), + "Tie break: Keep new connection (id: {}, peer: {}, direction: {}). Disconnect existing \ + connection (id: {}, peer: {}, direction: {})", + new_conn.id(), + new_conn.peer_node_id(), new_conn.direction(), - existing_conn.direction() + existing_conn.id(), + existing_conn.peer_node_id(), + existing_conn.direction(), ); - let node_id = existing_conn.peer_node_id().clone(); - let direction = existing_conn.direction(); - delayed_close(existing_conn.clone(), self.config.connection_tie_break_linger); - self.publish_event(ConnectivityEvent::PeerConnectionWillClose(node_id, direction)); + let _ = existing_conn.disconnect_silent().await; + self.pool.remove(existing_conn.peer_node_id()); }, Some(existing_conn) => { debug!( target: LOG_TARGET, - "Tie break: (Peer = {}) Keeping existing {} connection, Disconnecting new {} connection", - existing_conn.peer_node_id().short_str(), - existing_conn.direction(), + "Tie break: Keeping existing connection (id: {}, peer: {}, direction: {}). 
Disconnecting \ + new connection (id: {}, peer: {}, direction: {})", + existing_conn.id(), + existing_conn.peer_node_id(), + existing_conn.direction(), + new_conn.id(), + new_conn.peer_node_id(), new_conn.direction(), ); - delayed_close(new_conn.clone(), self.config.connection_tie_break_linger); + let _ = new_conn.clone().disconnect_silent().await; // Ignore this event - state can stay as is return Ok(()); }, + _ => {}, } }, - + PeerDisconnected(id, node_id) => { + if let Some(conn) = self.pool.get_connection(node_id) { + if conn.id() != *id { + debug!( + target: LOG_TARGET, + "Ignoring peer disconnected event for stale peer connection (id: {}) for peer '{}'", + id, + node_id + ); + return Ok(()); + } + } + }, _ => {}, } let (node_id, mut new_status, connection) = match event { - PeerDisconnected(node_id) => (&*node_id, ConnectionStatus::Disconnected, None), + PeerDisconnected(_, node_id) => (&*node_id, ConnectionStatus::Disconnected, None), PeerConnected(conn) => (conn.peer_node_id(), ConnectionStatus::Connected, Some(conn.clone())), PeerConnectFailed(node_id, ConnectionManagerError::DialCancelled) => { @@ -649,8 +655,8 @@ impl ConnectivityManagerActor { (Inbound, Outbound) => peer_node_id > our_node_id, // We connected to them at the same time as they connected to us (Outbound, Inbound) => our_node_id > peer_node_id, - // We connected to them twice for some reason. Drop the newer connection. - (Outbound, Outbound) => false, + // We connected to them twice for some reason. Drop the older connection.
+ (Outbound, Outbound) => true, } } @@ -818,16 +824,3 @@ impl ConnectivityManagerActor { } } } - -fn delayed_close(conn: PeerConnection, delay: Duration) { - task::spawn(async move { - time::sleep(delay).await; - debug!( - target: LOG_TARGET, - "Closing connection from peer `{}` after delay", - conn.peer_node_id() - ); - // Can ignore the error here, the error is already logged by peer connection - let _ = conn.clone().disconnect_silent().await; - }); -} diff --git a/comms/src/connectivity/requester.rs b/comms/src/connectivity/requester.rs index 50dca7dda9d..66451c5e6cc 100644 --- a/comms/src/connectivity/requester.rs +++ b/comms/src/connectivity/requester.rs @@ -39,11 +39,7 @@ use super::{ manager::ConnectivityStatus, ConnectivitySelection, }; -use crate::{ - connection_manager::{ConnectionDirection, ConnectionManagerError}, - peer_manager::NodeId, - PeerConnection, -}; +use crate::{connection_manager::ConnectionManagerError, peer_manager::NodeId, PeerConnection}; const LOG_TARGET: &str = "comms::connectivity::requester"; @@ -57,7 +53,6 @@ pub enum ConnectivityEvent { PeerConnectFailed(NodeId), PeerBanned(NodeId), PeerOffline(NodeId), - PeerConnectionWillClose(NodeId, ConnectionDirection), ConnectivityStateInitialized, ConnectivityStateOnline(usize), @@ -74,9 +69,6 @@ impl fmt::Display for ConnectivityEvent { PeerConnectFailed(node_id) => write!(f, "PeerConnectFailed({})", node_id), PeerBanned(node_id) => write!(f, "PeerBanned({})", node_id), PeerOffline(node_id) => write!(f, "PeerOffline({})", node_id), - PeerConnectionWillClose(node_id, direction) => { - write!(f, "PeerConnectionWillClose({}, {})", node_id, direction) - }, ConnectivityStateInitialized => write!(f, "ConnectivityStateInitialized"), ConnectivityStateOnline(n) => write!(f, "ConnectivityStateOnline({})", n), ConnectivityStateDegraded(n) => write!(f, "ConnectivityStateDegraded({})", n), diff --git a/comms/src/connectivity/test.rs b/comms/src/connectivity/test.rs index 23ec64a3e4a..d75030e0bea 100644 
--- a/comms/src/connectivity/test.rs +++ b/comms/src/connectivity/test.rs @@ -200,7 +200,10 @@ async fn online_then_offline() { )); for conn in connections.iter().skip(1) { - cm_mock_state.publish_event(ConnectionManagerEvent::PeerDisconnected(conn.peer_node_id().clone())); + cm_mock_state.publish_event(ConnectionManagerEvent::PeerDisconnected( + conn.id(), + conn.peer_node_id().clone(), + )); } streams::assert_in_broadcast( @@ -218,7 +221,10 @@ async fn online_then_offline() { // Disconnect client connections for conn in &client_connections { - cm_mock_state.publish_event(ConnectionManagerEvent::PeerDisconnected(conn.peer_node_id().clone())); + cm_mock_state.publish_event(ConnectionManagerEvent::PeerDisconnected( + conn.id(), + conn.peer_node_id().clone(), + )); } streams::assert_in_broadcast( @@ -389,7 +395,10 @@ async fn pool_management() { assert_eq!(conn.handle_count(), 2); // The peer connection mock does not "automatically" publish event to connectivity manager conn.disconnect().await.unwrap(); - cm_mock_state.publish_event(ConnectionManagerEvent::PeerDisconnected(conn.peer_node_id().clone())); + cm_mock_state.publish_event(ConnectionManagerEvent::PeerDisconnected( + conn.id(), + conn.peer_node_id().clone(), + )); } } @@ -407,6 +416,7 @@ async fn pool_management() { assert_eq!(conns.len(), 1); important_connection.disconnect().await.unwrap(); cm_mock_state.publish_event(ConnectionManagerEvent::PeerDisconnected( + important_connection.id(), important_connection.peer_node_id().clone(), )); drop(important_connection); diff --git a/comms/src/protocol/messaging/config.rs b/comms/src/protocol/messaging/config.rs deleted file mode 100644 index 59554280197..00000000000 --- a/comms/src/protocol/messaging/config.rs +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2020, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use std::time::Duration; - -#[derive(Debug, Clone)] -pub struct MessagingConfig { - /// The length of time that inactivity is allowed before closing the inbound/outbound substreams, or None for no - /// timeout - /// - /// Inbound/outbound substreams are closed independently, and they may be reopened in the future once closed. 
- /// (default: 8 mins) - pub inactivity_timeout: Option, -} - -impl Default for MessagingConfig { - fn default() -> Self { - Self { - inactivity_timeout: Some(Duration::from_secs(8 * 60)), - } - } -} diff --git a/comms/src/protocol/messaging/error.rs b/comms/src/protocol/messaging/error.rs index 2572b254397..e7fd9f5fb58 100644 --- a/comms/src/protocol/messaging/error.rs +++ b/comms/src/protocol/messaging/error.rs @@ -55,6 +55,4 @@ pub enum MessagingProtocolError { Io(#[from] io::Error), #[error("Sender error: {0}")] SenderError(#[from] mpsc::error::SendError), - #[error("Stream closed due to inactivity")] - Inactivity, } diff --git a/comms/src/protocol/messaging/extension.rs b/comms/src/protocol/messaging/extension.rs index 45cbce8007e..fe656e8de3b 100644 --- a/comms/src/protocol/messaging/extension.rs +++ b/comms/src/protocol/messaging/extension.rs @@ -81,7 +81,6 @@ where let (inbound_message_tx, inbound_message_rx) = mpsc::channel(INBOUND_MESSAGE_BUFFER_SIZE); let messaging = MessagingProtocol::new( - Default::default(), context.connectivity(), proto_rx, messaging_request_rx, diff --git a/comms/src/protocol/messaging/inbound.rs b/comms/src/protocol/messaging/inbound.rs index 30be7a8a6bb..7492323dc77 100644 --- a/comms/src/protocol/messaging/inbound.rs +++ b/comms/src/protocol/messaging/inbound.rs @@ -22,7 +22,7 @@ use std::{sync::Arc, time::Duration}; -use futures::{future::Either, StreamExt}; +use futures::StreamExt; use log::*; use tokio::{ io::{AsyncRead, AsyncWrite}, @@ -40,7 +40,6 @@ pub struct InboundMessaging { messaging_events_tx: broadcast::Sender>, rate_limit_capacity: usize, rate_limit_restock_interval: Duration, - inactivity_timeout: Option, } impl InboundMessaging { @@ -50,7 +49,6 @@ impl InboundMessaging { messaging_events_tx: broadcast::Sender>, rate_limit_capacity: usize, rate_limit_restock_interval: Duration, - inactivity_timeout: Option, ) -> Self { Self { peer, @@ -58,7 +56,6 @@ impl InboundMessaging { messaging_events_tx, rate_limit_capacity, 
rate_limit_restock_interval, - inactivity_timeout, } } @@ -75,16 +72,12 @@ impl InboundMessaging { let stream = MessagingProtocol::framed(socket).rate_limit(self.rate_limit_capacity, self.rate_limit_restock_interval); - let stream = match self.inactivity_timeout { - Some(timeout) => Either::Left(tokio_stream::StreamExt::timeout(stream, timeout)), - None => Either::Right(stream.map(Ok)), - }; tokio::pin!(stream); let inbound_count = metrics::inbound_message_count(&self.peer); while let Some(result) = stream.next().await { match result { - Ok(Ok(raw_msg)) => { + Ok(raw_msg) => { inbound_count.inc(); let msg_len = raw_msg.len(); let inbound_msg = InboundMessage::new(peer.clone(), raw_msg.freeze()); @@ -112,7 +105,7 @@ impl InboundMessaging { let _ = self.messaging_events_tx.send(Arc::new(event)); }, - Ok(Err(err)) => { + Err(err) => { metrics::error_count(peer).inc(); error!( target: LOG_TARGET, @@ -122,18 +115,6 @@ impl InboundMessaging { ); break; }, - - Err(_) => { - metrics::error_count(peer).inc(); - debug!( - target: LOG_TARGET, - "Inbound messaging for peer '{}' has stopped because it was inactive for {:.0?}", - peer.short_str(), - self.inactivity_timeout - .expect("Inactivity timeout reached but it was not enabled"), - ); - break; - }, } } diff --git a/comms/src/protocol/messaging/mod.rs b/comms/src/protocol/messaging/mod.rs index 7fd5703e50a..bdbf4eb0db4 100644 --- a/comms/src/protocol/messaging/mod.rs +++ b/comms/src/protocol/messaging/mod.rs @@ -20,9 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-mod config; -pub use config::MessagingConfig; - mod extension; pub use extension::MessagingProtocolExtension; diff --git a/comms/src/protocol/messaging/outbound.rs b/comms/src/protocol/messaging/outbound.rs index 2220373e2bd..9865b1fd059 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -20,10 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::time::{Duration, Instant}; +use std::time::Instant; -use futures::{future::Either, SinkExt, StreamExt, TryStreamExt}; -use tokio::sync::mpsc; +use futures::{future, SinkExt, StreamExt}; +use tokio::{pin, sync::mpsc}; use tracing::{debug, error, event, span, Instrument, Level}; use super::{error::MessagingProtocolError, metrics, MessagingEvent, MessagingProtocol, SendFailReason}; @@ -34,6 +34,7 @@ use crate::{ multiplexing::Substream, peer_manager::NodeId, protocol::messaging::protocol::MESSAGING_PROTOCOL, + stream_id::StreamId, }; const LOG_TARGET: &str = "comms::protocol::messaging::outbound"; @@ -48,7 +49,6 @@ pub struct OutboundMessaging { messaging_events_tx: mpsc::Sender, retry_queue_tx: mpsc::UnboundedSender, peer_node_id: NodeId, - inactivity_timeout: Option, } impl OutboundMessaging { @@ -58,7 +58,6 @@ impl OutboundMessaging { messages_rx: mpsc::UnboundedReceiver, retry_queue_tx: mpsc::UnboundedSender, peer_node_id: NodeId, - inactivity_timeout: Option, ) -> Self { Self { connectivity, @@ -66,7 +65,6 @@ impl OutboundMessaging { messaging_events_tx, retry_queue_tx, peer_node_id, - inactivity_timeout, } } @@ -97,17 +95,6 @@ impl OutboundMessaging { "Outbound messaging for peer '{}' has stopped because the stream was closed", peer_node_id ); }, - Err(MessagingProtocolError::Inactivity) => { - event!( - Level::DEBUG, - "Outbound messaging for peer '{}' has stopped because it was inactive", - peer_node_id - ); 
- debug!( - target: LOG_TARGET, - "Outbound messaging for peer '{}' has stopped because it was inactive", peer_node_id - ); - }, Err(MessagingProtocolError::PeerDialFailed(err)) => { debug!( target: LOG_TARGET, @@ -263,7 +250,6 @@ impl OutboundMessaging { ) -> Result<(), MessagingProtocolError> { let Self { mut messages_rx, - inactivity_timeout, peer_node_id, .. } = self; @@ -273,34 +259,31 @@ impl OutboundMessaging { node_id = peer_node_id.to_string().as_str() ); let _enter = span.enter(); + let stream_id = substream.stream.stream_id(); debug!( target: LOG_TARGET, - "Starting direct message forwarding for peer `{}`", peer_node_id + "Starting direct message forwarding for peer `{}` (stream: {})", peer_node_id, stream_id ); - let framed = MessagingProtocol::framed(substream.stream); + let (sink, mut remote_stream) = MessagingProtocol::framed(substream.stream).split(); // Convert unbounded channel to a stream - let stream = futures::stream::unfold(&mut messages_rx, |rx| async move { + let outbound_stream = futures::stream::unfold(&mut messages_rx, |rx| async move { let v = rx.recv().await; v.map(|v| (v, rx)) }); - let outbound_stream = match inactivity_timeout { - Some(timeout) => Either::Left( - tokio_stream::StreamExt::timeout(stream, timeout).map_err(|_| MessagingProtocolError::Inactivity), - ), - None => Either::Right(stream.map(Ok)), - }; - let outbound_count = metrics::outbound_message_count(&peer_node_id); - let stream = outbound_stream.map(|msg| { + let stream = outbound_stream.map(|mut out_msg| { outbound_count.inc(); - msg.map(|mut out_msg| { - event!(Level::DEBUG, "Message buffered for sending {}", out_msg); - out_msg.reply_success(); - out_msg.body - }) + event!( + Level::DEBUG, + "Message buffered for sending {} on stream {}", + out_msg, + stream_id + ); + out_msg.reply_success(); + Result::<_, MessagingProtocolError>::Ok(out_msg.body) }); // Stop the stream as soon as the disconnection occurs, this allows the outbound stream to terminate as soon as @@ 
-311,14 +294,18 @@ impl OutboundMessaging { // We drop the conn handle here BEFORE awaiting a disconnect to ensure that the outbound messaging isn't // holding onto the handle keeping the connection alive drop(conn); - on_disconnect.await; + // Read from the yamux socket to determine if it is closed. + let close_detect = remote_stream.next(); + pin!(on_disconnect); + pin!(close_detect); + future::select(on_disconnect, close_detect).await; debug!( target: LOG_TARGET, - "Peer connection closed. Ending outbound messaging stream for peer {}.", peer_node_id + "Outbound messaging stream {} ended for peer {}.", stream_id, peer_node_id ) }); - super::forward::Forward::new(stream, framed.sink_map_err(Into::into)).await?; + super::forward::Forward::new(stream, sink.sink_map_err(Into::into)).await?; // Close so that the protocol handler does not resend to this session messages_rx.close(); @@ -343,7 +330,7 @@ impl OutboundMessaging { debug!( target: LOG_TARGET, - "Direct message forwarding successfully completed for peer `{}`.", peer_node_id + "Direct message forwarding successfully completed for peer `{}` (stream: {}).", peer_node_id, stream_id ); Ok(()) } @@ -354,10 +341,6 @@ impl OutboundMessaging { self.messages_rx.close(); while let Some(mut out_msg) = self.messages_rx.recv().await { out_msg.reply_fail(reason); - let _ = self - .messaging_events_tx - .send(MessagingEvent::SendMessageFailed(out_msg, reason)) - .await; } } } diff --git a/comms/src/protocol/messaging/protocol.rs b/comms/src/protocol/messaging/protocol.rs index 2564eb880a1..6a391311fc1 100644 --- a/comms/src/protocol/messaging/protocol.rs +++ b/comms/src/protocol/messaging/protocol.rs @@ -39,13 +39,13 @@ use tokio_util::codec::{Framed, LengthDelimitedCodec}; use super::error::MessagingProtocolError; use crate::{ - connectivity::{ConnectivityEvent, ConnectivityRequester}, + connectivity::ConnectivityRequester, framing, message::{InboundMessage, MessageTag, OutboundMessage}, multiplexing::Substream, 
peer_manager::NodeId, protocol::{ - messaging::{inbound::InboundMessaging, outbound::OutboundMessaging, MessagingConfig}, + messaging::{inbound::InboundMessaging, outbound::OutboundMessaging}, ProtocolEvent, ProtocolNotification, }, @@ -90,7 +90,6 @@ pub enum SendFailReason { pub enum MessagingEvent { MessageReceived(NodeId, MessageTag), InvalidMessageReceived(NodeId), - SendMessageFailed(OutboundMessage, SendFailReason), OutboundProtocolExited(NodeId), } @@ -100,14 +99,12 @@ impl fmt::Display for MessagingEvent { match self { MessageReceived(node_id, tag) => write!(f, "MessageReceived({}, {})", node_id.short_str(), tag), InvalidMessageReceived(node_id) => write!(f, "InvalidMessageReceived({})", node_id.short_str()), - SendMessageFailed(out_msg, reason) => write!(f, "SendMessageFailed({}, Reason = {})", out_msg, reason), OutboundProtocolExited(node_id) => write!(f, "OutboundProtocolExited({})", node_id), } } } pub struct MessagingProtocol { - config: MessagingConfig, connectivity: ConnectivityRequester, proto_notification: mpsc::Receiver>, active_queues: HashMap>, @@ -125,7 +122,6 @@ pub struct MessagingProtocol { impl MessagingProtocol { #[allow(clippy::too_many_arguments)] pub fn new( - config: MessagingConfig, connectivity: ConnectivityRequester, proto_notification: mpsc::Receiver>, request_rx: mpsc::Receiver, @@ -138,7 +134,6 @@ impl MessagingProtocol { let (retry_queue_tx, retry_queue_rx) = mpsc::unbounded_channel(); Self { - config, connectivity, proto_notification, request_rx, @@ -160,7 +155,6 @@ impl MessagingProtocol { pub async fn run(mut self) { let mut shutdown_signal = self.shutdown_signal.clone(); - let mut connectivity_events = self.connectivity.get_event_subscription(); loop { tokio::select! 
{ @@ -188,12 +182,6 @@ impl MessagingProtocol { } }, - event = connectivity_events.recv() => { - if let Ok(event) = event { - self.handle_connectivity_event(&event); - } - } - Some(notification) = self.proto_notification.recv() => { self.handle_protocol_notification(notification).await; }, @@ -212,19 +200,6 @@ impl MessagingProtocol { framing::canonical(socket, MAX_FRAME_LENGTH) } - fn handle_connectivity_event(&mut self, event: &ConnectivityEvent) { - use ConnectivityEvent::*; - #[allow(clippy::single_match)] - match event { - PeerConnectionWillClose(node_id, _) => { - // If the peer connection will close, cut off the pipe to send further messages by dropping the sender. - // Any messages in the channel may be sent before the connection is disconnected. - let _ = self.active_queues.remove(node_id); - }, - _ => {}, - } - } - async fn handle_internal_messaging_event(&mut self, event: MessagingEvent) { use MessagingEvent::*; trace!(target: LOG_TARGET, "Internal messaging event '{}'", event); @@ -287,7 +262,6 @@ impl MessagingProtocol { self.connectivity.clone(), self.internal_messaging_event_tx.clone(), peer_node_id, - self.config.inactivity_timeout, self.retry_queue_tx.clone(), ); break entry.insert(sender); @@ -316,18 +290,10 @@ impl MessagingProtocol { connectivity: ConnectivityRequester, events_tx: mpsc::Sender, peer_node_id: NodeId, - inactivity_timeout: Option, retry_queue_tx: mpsc::UnboundedSender, ) -> mpsc::UnboundedSender { let (msg_tx, msg_rx) = mpsc::unbounded_channel(); - let outbound_messaging = OutboundMessaging::new( - connectivity, - events_tx, - msg_rx, - retry_queue_tx, - peer_node_id, - inactivity_timeout, - ); + let outbound_messaging = OutboundMessaging::new(connectivity, events_tx, msg_rx, retry_queue_tx, peer_node_id); task::spawn(outbound_messaging.run()); msg_tx } @@ -341,7 +307,6 @@ impl MessagingProtocol { messaging_events_tx, RATE_LIMIT_CAPACITY, RATE_LIMIT_RESTOCK_INTERVAL, - self.config.inactivity_timeout, ); 
task::spawn(inbound_messaging.run(substream)); } diff --git a/comms/src/protocol/messaging/test.rs b/comms/src/protocol/messaging/test.rs index 771fd4305bc..d65fc2992d5 100644 --- a/comms/src/protocol/messaging/test.rs +++ b/comms/src/protocol/messaging/test.rs @@ -20,14 +20,14 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{io, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use bytes::Bytes; use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use rand::rngs::OsRng; use tari_crypto::keys::PublicKey; use tari_shutdown::Shutdown; -use tari_test_utils::{collect_recv, collect_stream, unpack_enum}; +use tari_test_utils::{collect_stream, unpack_enum}; use tokio::{ sync::{broadcast, mpsc, oneshot}, time, @@ -41,16 +41,11 @@ use super::protocol::{ MESSAGING_PROTOCOL, }; use crate::{ - memsocket::MemorySocket, message::{InboundMessage, MessageTag, MessagingReplyRx, OutboundMessage}, multiplexing::Substream, net_address::MultiaddressesWithStats, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, - protocol::{ - messaging::{inbound::InboundMessaging, SendFailReason}, - ProtocolEvent, - ProtocolNotification, - }, + protocol::{messaging::SendFailReason, ProtocolEvent, ProtocolNotification}, runtime, runtime::task, test_utils::{ @@ -88,7 +83,6 @@ async fn spawn_messaging_protocol() -> ( let (events_tx, events_rx) = broadcast::channel(100); let msg_proto = MessagingProtocol::new( - Default::default(), requester, proto_rx, request_rx, @@ -196,15 +190,15 @@ async fn send_message_dial_failed() { let (_, _, conn_manager_mock, _, request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; let node_id = node_id::random(); - let out_msg = OutboundMessage::new(node_id, TEST_MSG1.clone()); - let expected_out_msg_tag = out_msg.tag; + let (reply_tx, reply_rx) = 
oneshot::channel(); + let out_msg = OutboundMessage::with_reply(node_id, TEST_MSG1.clone(), reply_tx.into()); // Send a message to node 2 request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); let event = event_tx.recv().await.unwrap(); - unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); - unpack_enum!(SendFailReason::PeerDialFailed = reason); - assert_eq!(out_msg.tag, expected_out_msg_tag); + unpack_enum!(MessagingEvent::OutboundProtocolExited(_node_id) = &*event); + let reply = reply_rx.await.unwrap().unwrap_err(); + unpack_enum!(SendFailReason::PeerDialFailed = reply); let calls = conn_manager_mock.take_calls().await; assert_eq!(calls.len(), 2); @@ -250,12 +244,25 @@ async fn send_message_substream_bulk_failure() { expected_out_msg_tags.push(send_msg(&mut request_tx, peer_node_id.clone()).await); } - // Expect all messages to have been buffered for sending - even if they never arrive because the sender suddenly - // disconnected. + // Expect some messages to fail sending because the sender suddenly disconnected and could not be redialled. 
+ // Others may pass due to the race between detecting disconnection and sending + let mut num_sent = 0usize; + let mut num_failed = 0usize; for (_, reply) in expected_out_msg_tags { - reply.await.unwrap().unwrap(); + match reply.await.unwrap() { + Ok(_) => { + num_sent += 1; + }, + Err(SendFailReason::PeerDialFailed) => { + num_failed += 1; + }, + Err(err) => unreachable!("Unexpected error {}", err), + } } + assert!(num_failed > 0); + assert_eq!(num_sent + num_failed, NUM_MSGS); + // Check that the outbound handler closed let event = time::timeout(Duration::from_secs(10), events_rx.recv()) .await @@ -316,7 +323,7 @@ async fn many_concurrent_send_message_requests() { #[runtime::test] async fn many_concurrent_send_message_requests_that_fail() { const NUM_MSGS: usize = 100; - let (_, _, _, _, request_tx, _, mut events_rx, _shutdown) = spawn_messaging_protocol().await; + let (_, _, _, _, request_tx, _, _, _shutdown) = spawn_messaging_protocol().await; let node_id2 = node_id::random(); @@ -336,55 +343,7 @@ async fn many_concurrent_send_message_requests_that_fail() { request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); } - // Check that we got message success events - let events = collect_recv!(events_rx, take = NUM_MSGS, timeout = Duration::from_secs(10)); - assert_eq!(events.len(), NUM_MSGS); - for event in events { - unpack_enum!(MessagingEvent::SendMessageFailed(out_msg, reason) = &*event); - unpack_enum!(SendFailReason::PeerDialFailed = reason); - // Assert that each tag is emitted only once - let index = msg_tags.iter().position(|t| t == &out_msg.tag).unwrap(); - msg_tags.remove(index); - } - let unordered = reply_rxs.into_iter().collect::>(); let results = unordered.collect::>().await; assert!(results.into_iter().map(|r| r.unwrap()).all(|r| r.is_err())); - - assert_eq!(msg_tags.len(), 0); -} - -#[runtime::test] -async fn inactivity_timeout() { - let node_identity = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT); - let (inbound_msg_tx, 
mut inbound_msg_rx) = mpsc::channel(5); - let (events_tx, _) = broadcast::channel(1); - - let (socket_in, socket_out) = MemorySocket::new_pair(); - - task::spawn( - InboundMessaging::new( - node_identity.node_id().clone(), - inbound_msg_tx, - events_tx, - 10, - Duration::from_millis(100), - Some(Duration::from_millis(5)), - ) - .run(socket_in), - ); - - // Write messages for 5 milliseconds - let mut framed = MessagingProtocol::framed(socket_out); - for _ in 0..5u8 { - framed.send(Bytes::from_static(b"some message")).await.unwrap(); - time::sleep(Duration::from_millis(1)).await; - } - - time::sleep(Duration::from_millis(10)).await; - - let err = framed.send(Bytes::from_static(b"another message")).await.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - - let _ = collect_recv!(inbound_msg_rx, take = 5, timeout = Duration::from_secs(10)); }