Skip to content

Commit

Permalink
fix(s2n-quic-dc): use wake_forced for worker::Waker (#2415)
Browse files Browse the repository at this point in the history
  • Loading branch information
camshaft authored Dec 11, 2024
1 parent 90f3956 commit 7e37e77
Show file tree
Hide file tree
Showing 10 changed files with 418 additions and 148 deletions.
5 changes: 4 additions & 1 deletion dc/s2n-quic-dc-benches/src/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ fn pair(
accept_flavor: accept::Flavor,
) -> (stream::testing::Client, stream::testing::Server) {
let client = stream::testing::Client::default();
let server = stream::testing::Server::new(protocol, accept_flavor);
let server = stream::testing::Server::builder()
.protocol(protocol)
.accept_flavor(accept_flavor)
.build();
client.handshake_with(&server).unwrap();
(client, server)
}
Expand Down
2 changes: 2 additions & 0 deletions dc/s2n-quic-dc/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub mod socket;

#[cfg(any(test, feature = "testing"))]
pub mod testing;
#[cfg(test)]
mod tests;

bitflags::bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down
2 changes: 1 addition & 1 deletion dc/s2n-quic-dc/src/stream/send/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tracing::Instrument as _;

fn pair(protocol: Protocol) -> (testing::Client, testing::Server) {
let client = testing::Client::default();
let server = testing::Server::new(protocol, Default::default());
let server = testing::Server::builder().protocol(protocol).build();
(client, server)
}

Expand Down
2 changes: 1 addition & 1 deletion dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ where
let mut context = worker::Context::new(&self);

poll_fn(move |cx| {
workers.update_task_context(cx);
workers.poll_start(cx);

let now = self.env.clock().get_time();
let publisher = publisher(&self.subscriber, &now);
Expand Down
31 changes: 11 additions & 20 deletions dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ where
{
inner: Inner<W>,
waker_set: waker::Set,
root_waker: Option<Waker>,
}

/// Split the tasks from the waker set to avoid ownership issues
Expand Down Expand Up @@ -120,11 +119,7 @@ where
sojourn_time: RttEstimator::new(Duration::from_secs(30)),
};

Self {
inner,
waker_set,
root_waker: None,
}
Self { inner, waker_set }
}

#[inline]
Expand All @@ -149,19 +144,8 @@ where

/// Must be called before polling any workers
#[inline]
pub fn update_task_context(&mut self, cx: &mut task::Context) {
let new_waker = cx.waker();

let root_task_requires_update = if let Some(waker) = self.root_waker.as_ref() {
!waker.will_wake(new_waker)
} else {
true
};

if root_task_requires_update {
self.waker_set.update_root(new_waker);
self.root_waker = Some(new_waker.clone());
}
pub fn poll_start(&mut self, cx: &mut task::Context) {
self.waker_set.poll_start(cx);
}

#[inline]
Expand Down Expand Up @@ -221,8 +205,15 @@ where
Pub: EndpointPublisher,
C: Clock,
{
let ready = self.waker_set.drain();

// no need to actually poll any workers if none are active
if self.inner.by_sojourn_time.is_empty() {
return ControlFlow::Continue(());
}

// poll any workers that are ready
for idx in self.waker_set.drain() {
for idx in ready {
if self.inner.poll_worker(idx, cx, publisher, clock).is_break() {
return ControlFlow::Break(());
}
Expand Down
Loading

0 comments on commit 7e37e77

Please sign in to comment.