diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs index 3f08b024467e..178a2b54af7c 100644 --- a/swarm/src/connection.rs +++ b/swarm/src/connection.rs @@ -36,7 +36,7 @@ use crate::handler::{ }; use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::{ - ConnectionHandlerEvent, KeepAlive, Stream, StreamCounter, StreamProtocol, StreamUpgradeError, + ConnectionHandlerEvent, KeepAlive, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol, }; use futures::future::BoxFuture; @@ -64,6 +64,9 @@ use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll}; static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1); +/// Counter of the number of active streams on a connection +type ActiveStreamCounter = Arc<()>; + /// Connection identifier. #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ConnectionId(usize); @@ -158,7 +161,8 @@ where local_supported_protocols: HashSet, remote_supported_protocols: HashSet, idle_timeout: Duration, - stream_counter: Arc<()>, + /// The counter of active streams + stream_counter: ActiveStreamCounter, } impl fmt::Debug for Connection @@ -348,70 +352,30 @@ where } } - // Ask the handler whether it wants the connection (and the handler itself) - // to be kept alive, which determines the planned shutdown, if any. let active_stream_count = Arc::strong_count(stream_counter); if active_stream_count == 1 { - let keep_alive = handler.connection_keep_alive(); - match (&mut *shutdown, keep_alive) { - (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => { - if *deadline != t { - *deadline = t; - if let Some(new_duration) = - deadline.checked_duration_since(Instant::now()) - { - let effective_keep_alive = max(new_duration, *idle_timeout); - - timer.reset(effective_keep_alive) - } - } - } - (_, KeepAlive::Until(earliest_shutdown)) => { - let now = Instant::now(); - - if let Some(requested) = earliest_shutdown.checked_duration_since(now) { - let effective_keep_alive = max(requested, *idle_timeout); - - let safe_keep_alive = checked_add_fraction(now, effective_keep_alive); - - // Important: We store the _original_ `Instant` given by the `ConnectionHandler` in the `Later` instance to ensure we can compare it in the above branch. - // This is quite subtle but will hopefully become simpler soon once `KeepAlive::Until` is fully deprecated. See / - *shutdown = - Shutdown::Later(Delay::new(safe_keep_alive), earliest_shutdown) - } - } - (_, KeepAlive::No) if idle_timeout == &Duration::ZERO => { - *shutdown = Shutdown::Asap; - } - (Shutdown::Later(_, _), KeepAlive::No) => { - // Do nothing, i.e. let the shutdown timer continue to tick. - } - (_, KeepAlive::No) => { - let now = Instant::now(); - let safe_keep_alive = checked_add_fraction(now, *idle_timeout); - - *shutdown = - Shutdown::Later(Delay::new(safe_keep_alive), now + safe_keep_alive); - } - (_, KeepAlive::Yes) => *shutdown = Shutdown::None, - }; - } - - // Check if the connection (and handler) should be shut down. - // As long as we're still negotiating substreams, shutdown is always postponed. - if negotiating_in.is_empty() - && negotiating_out.is_empty() - && requested_substreams.is_empty() - { - match shutdown { - Shutdown::None => {} - Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)), - Shutdown::Later(delay, _) => match Future::poll(Pin::new(delay), cx) { - Poll::Ready(_) => { + // Ask the handler whether it wants the connection (and the handler itself) + // to be kept alive, which determines the planned shutdown, if any. + handle_should_keep_alive(handler, shutdown, idle_timeout); + + // Check if the connection (and handler) should be shut down. + // As long as we're still negotiating substreams, shutdown is always postponed. + if negotiating_in.is_empty() + && negotiating_out.is_empty() + && requested_substreams.is_empty() + { + match shutdown { + Shutdown::None => {} + Shutdown::Asap => { return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)) } - Poll::Pending => {} - }, + Shutdown::Later(delay, _) => match Future::poll(Pin::new(delay), cx) { + Poll::Ready(_) => { + return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)) + } + Poll::Pending => {} + }, + } } } @@ -496,6 +460,52 @@ fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet { + if *deadline != t { + *deadline = t; + if let Some(new_duration) = deadline.checked_duration_since(Instant::now()) { + let effective_keep_alive = max(new_duration, *idle_timeout); + + timer.reset(effective_keep_alive) + } + } + } + (_, KeepAlive::Until(earliest_shutdown)) => { + let now = Instant::now(); + + if let Some(requested) = earliest_shutdown.checked_duration_since(now) { + let effective_keep_alive = max(requested, *idle_timeout); + + let safe_keep_alive = checked_add_fraction(now, effective_keep_alive); + + // Important: We store the _original_ `Instant` given by the `ConnectionHandler` in the `Later` instance to ensure we can compare it in the above branch. + // This is quite subtle but will hopefully become simpler soon once `KeepAlive::Until` is fully deprecated. See / + *shutdown = Shutdown::Later(Delay::new(safe_keep_alive), earliest_shutdown) + } + } + (_, KeepAlive::No) if idle_timeout == &Duration::ZERO => { + *shutdown = Shutdown::Asap; + } + (Shutdown::Later(_, _), KeepAlive::No) => { + // Do nothing, i.e. let the shutdown timer continue to tick. + } + (_, KeepAlive::No) => { + let now = Instant::now(); + let safe_keep_alive = checked_add_fraction(now, *idle_timeout); + + *shutdown = Shutdown::Later(Delay::new(safe_keep_alive), now + safe_keep_alive); + } + (_, KeepAlive::Yes) => *shutdown = Shutdown::None, + }; +} + /// Repeatedly halves and adds the [`Duration`] to the [`Instant`] until [`Instant::checked_add`] succeeds. /// /// [`Instant`] depends on the underlying platform and has a limit of which points in time it can represent. @@ -572,7 +582,6 @@ impl StreamUpgrade { ) .await .map_err(to_stream_upgrade_error)?; - let counter = StreamCounter::Arc(counter); let output = upgrade .upgrade_outbound(Stream::new(stream, counter), info) @@ -606,7 +615,6 @@ impl StreamUpgrade { multistream_select::listener_select_proto(substream, protocols) .await .map_err(to_stream_upgrade_error)?; - let counter = StreamCounter::Arc(counter); let output = upgrade .upgrade_inbound(Stream::new(stream, counter), info) diff --git a/swarm/src/handler.rs b/swarm/src/handler.rs index 9374903f9b74..7b076a7b04a9 100644 --- a/swarm/src/handler.rs +++ b/swarm/src/handler.rs @@ -125,10 +125,11 @@ pub trait ConnectionHandler: Send + 'static { /// Returns until when the connection should be kept alive. /// - /// This method is called by the `Swarm` after each invocation of - /// [`ConnectionHandler::poll`] to determine if the connection and the associated - /// [`ConnectionHandler`]s should be kept alive as far as this handler is concerned - /// and if so, for how long. + /// `Swarm` checks if there are still active streams on this connection after + /// each invocation of [`ConnectionHandler::poll`]. If no, this method will + /// be called by the `Swarm` to determine if the connection and the associated + /// [`ConnectionHandler`]s should be kept alive as far as this handler is + /// concerned and if so, for how long. /// /// Returning [`KeepAlive::No`] indicates that the connection should be /// closed and this handler destroyed immediately. diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 0854a73c58b7..52498156f1f0 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -126,7 +126,7 @@ pub use handler::{ #[cfg(feature = "macros")] pub use libp2p_swarm_derive::NetworkBehaviour; pub use listen_opts::ListenOpts; -pub use stream::{Stream, StreamCounter}; +pub use stream::Stream; pub use stream_protocol::{InvalidProtocol, StreamProtocol}; use crate::behaviour::ExternalAddrConfirmed; diff --git a/swarm/src/stream.rs b/swarm/src/stream.rs index 1964831425d7..d39998fd4c72 100644 --- a/swarm/src/stream.rs +++ b/swarm/src/stream.rs @@ -11,30 +11,27 @@ use std::{ #[derive(Debug)] pub struct Stream { stream: Negotiated, - stream_counter: StreamCounter, + counter: StreamCounter, } #[derive(Debug)] -pub enum StreamCounter { +enum StreamCounter { Arc(Arc<()>), Weak(Weak<()>), } impl Stream { - pub(crate) fn new(stream: Negotiated, stream_counter: StreamCounter) -> Self { - Self { - stream, - stream_counter, - } + pub(crate) fn new(stream: Negotiated, counter: Arc<()>) -> Self { + let counter = StreamCounter::Arc(counter); + Self { stream, counter } } + /// downgrade the Arc<()> to a Weak<()> which automatically + /// reduces the strong_count in Connection's stream_counter pub fn no_keep_alive(&mut self) { - let stream_counter = match &self.stream_counter { - StreamCounter::Arc(arc_counter) => StreamCounter::Weak(Arc::downgrade(arc_counter)), - StreamCounter::Weak(weak_counter) => StreamCounter::Weak(weak_counter.clone()), - }; - - self.stream_counter = stream_counter; + if let StreamCounter::Arc(arc_counter) = &self.counter { + self.counter = StreamCounter::Weak(Arc::downgrade(arc_counter)); + } } }