Skip to content

Commit

Permalink
apply suggested fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
leonzchang committed Oct 11, 2023
1 parent 3d8521a commit b98be2b
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 82 deletions.
136 changes: 72 additions & 64 deletions swarm/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::handler::{
};
use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
use crate::{
ConnectionHandlerEvent, KeepAlive, Stream, StreamCounter, StreamProtocol, StreamUpgradeError,
ConnectionHandlerEvent, KeepAlive, Stream, StreamProtocol, StreamUpgradeError,
SubstreamProtocol,
};
use futures::future::BoxFuture;
Expand Down Expand Up @@ -64,6 +64,9 @@ use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};

static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Counter of the number of active streams on a connection
type ActiveStreamCounter = Arc<()>;

/// Connection identifier.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);
Expand Down Expand Up @@ -158,7 +161,8 @@ where
local_supported_protocols: HashSet<StreamProtocol>,
remote_supported_protocols: HashSet<StreamProtocol>,
idle_timeout: Duration,
stream_counter: Arc<()>,
/// The counter of active streams
stream_counter: ActiveStreamCounter,
}

impl<THandler> fmt::Debug for Connection<THandler>
Expand Down Expand Up @@ -348,70 +352,30 @@ where
}
}

// Ask the handler whether it wants the connection (and the handler itself)
// to be kept alive, which determines the planned shutdown, if any.
let active_stream_count = Arc::strong_count(stream_counter);
if active_stream_count == 1 {
let keep_alive = handler.connection_keep_alive();
match (&mut *shutdown, keep_alive) {
(Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => {
if *deadline != t {
*deadline = t;
if let Some(new_duration) =
deadline.checked_duration_since(Instant::now())
{
let effective_keep_alive = max(new_duration, *idle_timeout);

timer.reset(effective_keep_alive)
}
}
}
(_, KeepAlive::Until(earliest_shutdown)) => {
let now = Instant::now();

if let Some(requested) = earliest_shutdown.checked_duration_since(now) {
let effective_keep_alive = max(requested, *idle_timeout);

let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);

// Important: We store the _original_ `Instant` given by the `ConnectionHandler` in the `Later` instance to ensure we can compare it in the above branch.
// This is quite subtle but will hopefully become simpler soon once `KeepAlive::Until` is fully deprecated. See <https://github.com/libp2p/rust-libp2p/issues/3844>/
*shutdown =
Shutdown::Later(Delay::new(safe_keep_alive), earliest_shutdown)
}
}
(_, KeepAlive::No) if idle_timeout == &Duration::ZERO => {
*shutdown = Shutdown::Asap;
}
(Shutdown::Later(_, _), KeepAlive::No) => {
// Do nothing, i.e. let the shutdown timer continue to tick.
}
(_, KeepAlive::No) => {
let now = Instant::now();
let safe_keep_alive = checked_add_fraction(now, *idle_timeout);

*shutdown =
Shutdown::Later(Delay::new(safe_keep_alive), now + safe_keep_alive);
}
(_, KeepAlive::Yes) => *shutdown = Shutdown::None,
};
}

// Check if the connection (and handler) should be shut down.
// As long as we're still negotiating substreams, shutdown is always postponed.
if negotiating_in.is_empty()
&& negotiating_out.is_empty()
&& requested_substreams.is_empty()
{
match shutdown {
Shutdown::None => {}
Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
Shutdown::Later(delay, _) => match Future::poll(Pin::new(delay), cx) {
Poll::Ready(_) => {
// Ask the handler whether it wants the connection (and the handler itself)
// to be kept alive, which determines the planned shutdown, if any.
handle_should_keep_alive(handler, shutdown, idle_timeout);

// Check if the connection (and handler) should be shut down.
// As long as we're still negotiating substreams, shutdown is always postponed.
if negotiating_in.is_empty()
&& negotiating_out.is_empty()
&& requested_substreams.is_empty()
{
match shutdown {
Shutdown::None => {}
Shutdown::Asap => {
return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
}
Poll::Pending => {}
},
Shutdown::Later(delay, _) => match Future::poll(Pin::new(delay), cx) {
Poll::Ready(_) => {
return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
}
Poll::Pending => {}
},
}
}
}

Expand Down Expand Up @@ -496,6 +460,52 @@ fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet<Strea
.collect()
}

fn handle_should_keep_alive(
handler: &impl ConnectionHandler,
shutdown: &mut Shutdown,
idle_timeout: &mut Duration,
) {
let keep_alive = handler.connection_keep_alive();
match (&mut *shutdown, keep_alive) {
(Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => {
if *deadline != t {
*deadline = t;
if let Some(new_duration) = deadline.checked_duration_since(Instant::now()) {
let effective_keep_alive = max(new_duration, *idle_timeout);

timer.reset(effective_keep_alive)
}
}
}
(_, KeepAlive::Until(earliest_shutdown)) => {
let now = Instant::now();

if let Some(requested) = earliest_shutdown.checked_duration_since(now) {
let effective_keep_alive = max(requested, *idle_timeout);

let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);

// Important: We store the _original_ `Instant` given by the `ConnectionHandler` in the `Later` instance to ensure we can compare it in the above branch.
// This is quite subtle but will hopefully become simpler soon once `KeepAlive::Until` is fully deprecated. See <https://github.com/libp2p/rust-libp2p/issues/3844>/
*shutdown = Shutdown::Later(Delay::new(safe_keep_alive), earliest_shutdown)
}
}
(_, KeepAlive::No) if idle_timeout == &Duration::ZERO => {
*shutdown = Shutdown::Asap;
}
(Shutdown::Later(_, _), KeepAlive::No) => {
// Do nothing, i.e. let the shutdown timer continue to tick.
}
(_, KeepAlive::No) => {
let now = Instant::now();
let safe_keep_alive = checked_add_fraction(now, *idle_timeout);

*shutdown = Shutdown::Later(Delay::new(safe_keep_alive), now + safe_keep_alive);
}
(_, KeepAlive::Yes) => *shutdown = Shutdown::None,
};
}

/// Repeatedly halves and adds the [`Duration`] to the [`Instant`] until [`Instant::checked_add`] succeeds.
///
/// [`Instant`] depends on the underlying platform and has a limit of which points in time it can represent.
Expand Down Expand Up @@ -572,7 +582,6 @@ impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
)
.await
.map_err(to_stream_upgrade_error)?;
let counter = StreamCounter::Arc(counter);

let output = upgrade
.upgrade_outbound(Stream::new(stream, counter), info)
Expand Down Expand Up @@ -606,7 +615,6 @@ impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
multistream_select::listener_select_proto(substream, protocols)
.await
.map_err(to_stream_upgrade_error)?;
let counter = StreamCounter::Arc(counter);

let output = upgrade
.upgrade_inbound(Stream::new(stream, counter), info)
Expand Down
9 changes: 5 additions & 4 deletions swarm/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ pub trait ConnectionHandler: Send + 'static {

/// Returns until when the connection should be kept alive.
///
/// This method is called by the `Swarm` after each invocation of
/// [`ConnectionHandler::poll`] to determine if the connection and the associated
/// [`ConnectionHandler`]s should be kept alive as far as this handler is concerned
/// and if so, for how long.
/// `Swarm` checks if there are still active streams on this connection after
/// each invocation of [`ConnectionHandler::poll`]. If no, this method will
/// be called by the `Swarm` to determine if the connection and the associated
/// [`ConnectionHandler`]s should be kept alive as far as this handler is
/// concerned and if so, for how long.
///
/// Returning [`KeepAlive::No`] indicates that the connection should be
/// closed and this handler destroyed immediately.
Expand Down
2 changes: 1 addition & 1 deletion swarm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub use handler::{
#[cfg(feature = "macros")]
pub use libp2p_swarm_derive::NetworkBehaviour;
pub use listen_opts::ListenOpts;
pub use stream::{Stream, StreamCounter};
pub use stream::Stream;
pub use stream_protocol::{InvalidProtocol, StreamProtocol};

use crate::behaviour::ExternalAddrConfirmed;
Expand Down
23 changes: 10 additions & 13 deletions swarm/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,27 @@ use std::{
#[derive(Debug)]
pub struct Stream {
stream: Negotiated<SubstreamBox>,
stream_counter: StreamCounter,
counter: StreamCounter,
}

#[derive(Debug)]
pub enum StreamCounter {
enum StreamCounter {
Arc(Arc<()>),
Weak(Weak<()>),
}

impl Stream {
pub(crate) fn new(stream: Negotiated<SubstreamBox>, stream_counter: StreamCounter) -> Self {
Self {
stream,
stream_counter,
}
pub(crate) fn new(stream: Negotiated<SubstreamBox>, counter: Arc<()>) -> Self {
let counter = StreamCounter::Arc(counter);
Self { stream, counter }
}

/// downgrade the Arc<()> to a Weak<()> which automatically
/// reduces the strong_count in Connection's stream_counter
pub fn no_keep_alive(&mut self) {
let stream_counter = match &self.stream_counter {
StreamCounter::Arc(arc_counter) => StreamCounter::Weak(Arc::downgrade(arc_counter)),
StreamCounter::Weak(weak_counter) => StreamCounter::Weak(weak_counter.clone()),
};

self.stream_counter = stream_counter;
if let StreamCounter::Arc(arc_counter) = &self.counter {
self.counter = StreamCounter::Weak(Arc::downgrade(arc_counter));
}
}
}

Expand Down

0 comments on commit b98be2b

Please sign in to comment.