From 7e37e7748c9e9a6f0b457acf39aaa0c48e1a3bc0 Mon Sep 17 00:00:00 2001 From: Cameron Bytheway Date: Wed, 11 Dec 2024 14:50:21 -0700 Subject: [PATCH] fix(s2n-quic-dc): use wake_forced for worker::Waker (#2415) --- dc/s2n-quic-dc-benches/src/streams.rs | 5 +- dc/s2n-quic-dc/src/stream.rs | 2 + dc/s2n-quic-dc/src/stream/send/tests.rs | 2 +- dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs | 2 +- .../src/stream/server/tokio/tcp/manager.rs | 31 +- dc/s2n-quic-dc/src/stream/testing.rs | 311 +++++++++++------- dc/s2n-quic-dc/src/stream/tests.rs | 4 + .../src/stream/tests/accept_queue.rs | 179 ++++++++++ dc/s2n-quic-dc/src/task/waker/set.rs | 25 +- dc/s2n-quic-dc/src/task/waker/worker.rs | 5 + 10 files changed, 418 insertions(+), 148 deletions(-) create mode 100644 dc/s2n-quic-dc/src/stream/tests.rs create mode 100644 dc/s2n-quic-dc/src/stream/tests/accept_queue.rs diff --git a/dc/s2n-quic-dc-benches/src/streams.rs b/dc/s2n-quic-dc-benches/src/streams.rs index 69ae425117..3cd0c21864 100644 --- a/dc/s2n-quic-dc-benches/src/streams.rs +++ b/dc/s2n-quic-dc-benches/src/streams.rs @@ -37,7 +37,10 @@ fn pair( accept_flavor: accept::Flavor, ) -> (stream::testing::Client, stream::testing::Server) { let client = stream::testing::Client::default(); - let server = stream::testing::Server::new(protocol, accept_flavor); + let server = stream::testing::Server::builder() + .protocol(protocol) + .accept_flavor(accept_flavor) + .build(); client.handshake_with(&server).unwrap(); (client, server) } diff --git a/dc/s2n-quic-dc/src/stream.rs b/dc/s2n-quic-dc/src/stream.rs index 2ad810d6e4..a110eb2e10 100644 --- a/dc/s2n-quic-dc/src/stream.rs +++ b/dc/s2n-quic-dc/src/stream.rs @@ -28,6 +28,8 @@ pub mod socket; #[cfg(any(test, feature = "testing"))] pub mod testing; +#[cfg(test)] +mod tests; bitflags::bitflags! { #[derive(Clone, Copy, Debug, PartialEq, Eq)] diff --git a/dc/s2n-quic-dc/src/stream/send/tests.rs b/dc/s2n-quic-dc/src/stream/send/tests.rs index 4d7ab31f3a..05270272fe 100644 --- a/dc/s2n-quic-dc/src/stream/send/tests.rs +++ b/dc/s2n-quic-dc/src/stream/send/tests.rs @@ -9,7 +9,7 @@ use tracing::Instrument as _; fn pair(protocol: Protocol) -> (testing::Client, testing::Server) { let client = testing::Client::default(); - let server = testing::Server::new(protocol, Default::default()); + let server = testing::Server::builder().protocol(protocol).build(); (client, server) } diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs index c6ba47dc77..72a968281b 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs @@ -79,7 +79,7 @@ where let mut context = worker::Context::new(&self); poll_fn(move |cx| { - workers.update_task_context(cx); + workers.poll_start(cx); let now = self.env.clock().get_time(); let publisher = publisher(&self.subscriber, &now); diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs index 44b1bfb956..adbaa24085 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/manager.rs @@ -31,7 +31,6 @@ where { inner: Inner, waker_set: waker::Set, - root_waker: Option, } /// Split the tasks from the waker set to avoid ownership issues @@ -120,11 +119,7 @@ where sojourn_time: RttEstimator::new(Duration::from_secs(30)), }; - Self { - inner, - waker_set, - root_waker: None, - } + Self { inner, waker_set } } #[inline] @@ -149,19 +144,8 @@ where /// Must be called before polling any workers #[inline] - pub fn update_task_context(&mut self, cx: &mut task::Context) { - let new_waker = cx.waker(); - - let root_task_requires_update = if let Some(waker) = self.root_waker.as_ref() { - !waker.will_wake(new_waker) - } else { - true - }; - - if root_task_requires_update { - self.waker_set.update_root(new_waker); - self.root_waker = Some(new_waker.clone()); - } + pub fn poll_start(&mut self, cx: &mut task::Context) { + self.waker_set.poll_start(cx); } #[inline] @@ -221,8 +205,15 @@ where Pub: EndpointPublisher, C: Clock, { + let ready = self.waker_set.drain(); + + // no need to actually poll any workers if none are active + if self.inner.by_sojourn_time.is_empty() { + return ControlFlow::Continue(()); + } + // poll any workers that are ready - for idx in self.waker_set.drain() { + for idx in ready { if self.inner.poll_worker(idx, cx, publisher, clock).is_break() { return ControlFlow::Break(()); } diff --git a/dc/s2n-quic-dc/src/stream/testing.rs b/dc/s2n-quic-dc/src/stream/testing.rs index 9100c65980..d6d0a53d9c 100644 --- a/dc/s2n-quic-dc/src/stream/testing.rs +++ b/dc/s2n-quic-dc/src/stream/testing.rs @@ -7,9 +7,9 @@ use crate::{ path::secret, stream::{ application::Stream, - client::tokio as client, + client::tokio as stream_client, environment::{tokio as env, Environment as _}, - server::tokio::{self as server, accept}, + server::{tokio as stream_server, tokio::accept}, }, }; use std::{io, net::SocketAddr}; @@ -34,7 +34,7 @@ impl Default for Client { } impl Client { - pub fn handshake_with>( + pub fn handshake_with>( &self, server: &S, ) -> io::Result { @@ -54,7 +54,7 @@ impl Client { }) } - pub async fn connect_to>( + pub async fn connect_to>( &self, server: &S, ) -> io::Result> { @@ -65,33 +65,35 @@ impl Client { match server.protocol { Protocol::Tcp => { - client::connect_tcp(handshake, server.local_addr, &self.env, subscriber).await + stream_client::connect_tcp(handshake, server.local_addr, &self.env, subscriber) + .await } Protocol::Udp => { - client::connect_udp(handshake, server.local_addr, &self.env, subscriber).await + stream_client::connect_udp(handshake, server.local_addr, &self.env, subscriber) + .await } Protocol::Other(name) => { todo!("protocol {name:?} not implemented") } } } -} -#[derive(Clone)] -pub struct ServerHandle { - map: secret::Map, - protocol: Protocol, - local_addr: SocketAddr, -} + pub async fn connect_tcp_with>( + &self, + server: &S, + stream: tokio::net::TcpStream, + ) -> io::Result> { + let server = server.as_ref(); + let handshake = async { self.handshake_with(server) }.await?; + + let subscriber = Subscriber::default(); -impl AsRef for ServerHandle { - fn as_ref(&self) -> &ServerHandle { - self + stream_client::connect_tcp_with(handshake, stream, &self.env, subscriber).await } } pub struct Server { - handle: ServerHandle, + handle: server::Handle, receiver: accept::Receiver, stats: stats::Sender, #[allow(dead_code)] @@ -100,125 +102,34 @@ pub struct Server { impl Default for Server { fn default() -> Self { - Self::new_udp(accept::Flavor::Fifo) + Self::tcp().build() } } -impl AsRef for Server { - fn as_ref(&self) -> &ServerHandle { +impl AsRef for Server { + fn as_ref(&self) -> &server::Handle { &self.handle } } impl Server { - pub fn new_tcp(accept_flavor: accept::Flavor) -> Self { - Self::new(Protocol::Tcp, accept_flavor) + pub fn builder() -> server::Builder { + server::Builder::default() } - pub fn new_udp(accept_flavor: accept::Flavor) -> Self { - Self::new(Protocol::Udp, accept_flavor) + pub fn tcp() -> server::Builder { + Self::builder().tcp() } - pub fn new(protocol: Protocol, accept_flavor: accept::Flavor) -> Self { - if s2n_quic_platform::io::testing::is_in_env() { - todo!() - } else { - Self::new_tokio(protocol, accept_flavor) - } + pub fn udp() -> server::Builder { + Self::builder().udp() } - fn new_tokio(protocol: Protocol, accept_flavor: accept::Flavor) -> Self { - let _span = tracing::info_span!("server").entered(); - let map = secret::map::testing::new(16); - let (sender, receiver) = accept::channel(16); - - let options = crate::socket::Options::new("127.0.0.1:0".parse().unwrap()); - - let env = env::Builder::default().build().unwrap(); - - let subscriber = event::tracing::Subscriber::default(); - let (drop_handle_sender, drop_handle_receiver) = drop_handle::new(); - - let local_addr = match protocol { - Protocol::Tcp => { - let socket = options.build_tcp_listener().unwrap(); - let local_addr = socket.local_addr().unwrap(); - let socket = tokio::net::TcpListener::from_std(socket).unwrap(); - - let acceptor = server::tcp::Acceptor::new( - 0, - socket, - &sender, - &env, - &map, - 16, - accept_flavor, - subscriber, - ); - let acceptor = drop_handle_receiver.wrap(acceptor.run()); - let acceptor = acceptor.instrument(tracing::info_span!("tcp")); - tokio::task::spawn(acceptor); - - local_addr - } - Protocol::Udp => { - let socket = options.build_udp().unwrap(); - let local_addr = socket.local_addr().unwrap(); - - let socket = tokio::io::unix::AsyncFd::new(socket).unwrap(); - - let acceptor = server::udp::Acceptor::new( - 0, - socket, - &sender, - &env, - &map, - accept_flavor, - subscriber, - ); - let acceptor = drop_handle_receiver.wrap(acceptor.run()); - let acceptor = acceptor.instrument(tracing::info_span!("udp")); - tokio::task::spawn(acceptor); - - local_addr - } - Protocol::Other(name) => { - todo!("protocol {name:?} not implemented") - } - }; - - let (stats_sender, stats_worker, stats) = stats::channel(); - - { - let task = stats_worker.run(env.clock().clone()); - let task = task.instrument(tracing::info_span!("stats")); - let task = drop_handle_receiver.wrap(task); - tokio::task::spawn(task); - } - - if matches!(accept_flavor, accept::Flavor::Lifo) { - let channel = receiver.downgrade(); - let task = accept::Pruner::default().run(env, channel, stats); - let task = task.instrument(tracing::info_span!("pruner")); - let task = drop_handle_receiver.wrap(task); - tokio::task::spawn(task); - } - - let handle = ServerHandle { - map, - protocol, - local_addr, - }; - - Self { - handle, - receiver, - stats: stats_sender, - drop_handle: drop_handle_sender, - } + pub fn local_addr(&self) -> SocketAddr { + self.as_ref().local_addr } - pub fn handle(&self) -> ServerHandle { + pub fn handle(&self) -> server::Handle { self.handle.clone() } @@ -256,3 +167,163 @@ mod drop_handle { pub struct Sender(#[allow(dead_code)] watch::Sender<()>); } + +pub mod server { + use super::*; + + #[derive(Clone)] + pub struct Handle { + pub(super) map: secret::Map, + pub(super) protocol: Protocol, + pub(super) local_addr: SocketAddr, + } + + impl AsRef for Handle { + fn as_ref(&self) -> &Handle { + self + } + } + + pub struct Builder { + backlog: usize, + flavor: accept::Flavor, + protocol: Protocol, + map_capacity: usize, + } + + impl Default for Builder { + fn default() -> Self { + Self { + backlog: 16, + flavor: accept::Flavor::default(), + protocol: Protocol::Tcp, + map_capacity: 16, + } + } + } + + impl Builder { + pub fn build(self) -> Server { + if s2n_quic_platform::io::testing::is_in_env() { + todo!() + } else { + self.build_tokio() + } + } + + pub fn tcp(mut self) -> Self { + self.protocol = Protocol::Tcp; + self + } + + pub fn udp(mut self) -> Self { + self.protocol = Protocol::Udp; + self + } + + pub fn protocol(mut self, protocol: Protocol) -> Self { + self.protocol = protocol; + self + } + + pub fn backlog(mut self, backlog: usize) -> Self { + self.backlog = backlog; + self + } + + pub fn map_capacity(mut self, map_capacity: usize) -> Self { + self.map_capacity = map_capacity; + self + } + + pub fn accept_flavor(mut self, flavor: accept::Flavor) -> Self { + self.flavor = flavor; + self + } + + fn build_tokio(self) -> super::Server { + let Self { + backlog, + flavor, + protocol, + map_capacity, + } = self; + + let _span = tracing::info_span!("server").entered(); + let map = secret::map::testing::new(map_capacity); + let (sender, receiver) = accept::channel(backlog); + + let options = crate::socket::Options::new("127.0.0.1:0".parse().unwrap()); + + let env = env::Builder::default().build().unwrap(); + + let subscriber = event::tracing::Subscriber::default(); + let (drop_handle_sender, drop_handle_receiver) = drop_handle::new(); + + let local_addr = match protocol { + Protocol::Tcp => { + let socket = options.build_tcp_listener().unwrap(); + let local_addr = socket.local_addr().unwrap(); + let socket = tokio::net::TcpListener::from_std(socket).unwrap(); + + let acceptor = stream_server::tcp::Acceptor::new( + 0, socket, &sender, &env, &map, backlog, flavor, subscriber, + ); + let acceptor = drop_handle_receiver.wrap(acceptor.run()); + let acceptor = acceptor.instrument(tracing::info_span!("tcp")); + tokio::task::spawn(acceptor); + + local_addr + } + Protocol::Udp => { + let socket = options.build_udp().unwrap(); + let local_addr = socket.local_addr().unwrap(); + + let socket = tokio::io::unix::AsyncFd::new(socket).unwrap(); + + let acceptor = stream_server::udp::Acceptor::new( + 0, socket, &sender, &env, &map, flavor, subscriber, + ); + let acceptor = drop_handle_receiver.wrap(acceptor.run()); + let acceptor = acceptor.instrument(tracing::info_span!("udp")); + tokio::task::spawn(acceptor); + + local_addr + } + Protocol::Other(name) => { + todo!("protocol {name:?} not implemented") + } + }; + + let (stats_sender, stats_worker, stats) = stats::channel(); + + { + let task = stats_worker.run(env.clock().clone()); + let task = task.instrument(tracing::info_span!("stats")); + let task = drop_handle_receiver.wrap(task); + tokio::task::spawn(task); + } + + if matches!(flavor, accept::Flavor::Lifo) { + let channel = receiver.downgrade(); + let task = accept::Pruner::default().run(env, channel, stats); + let task = task.instrument(tracing::info_span!("pruner")); + let task = drop_handle_receiver.wrap(task); + tokio::task::spawn(task); + } + + let handle = server::Handle { + map, + protocol, + local_addr, + }; + + super::Server { + handle, + receiver, + stats: stats_sender, + drop_handle: drop_handle_sender, + } + } + } +} diff --git a/dc/s2n-quic-dc/src/stream/tests.rs b/dc/s2n-quic-dc/src/stream/tests.rs new file mode 100644 index 0000000000..32239b8055 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/tests.rs @@ -0,0 +1,4 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +mod accept_queue; diff --git a/dc/s2n-quic-dc/src/stream/tests/accept_queue.rs b/dc/s2n-quic-dc/src/stream/tests/accept_queue.rs new file mode 100644 index 0000000000..aee351c958 --- /dev/null +++ b/dc/s2n-quic-dc/src/stream/tests/accept_queue.rs @@ -0,0 +1,179 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + stream::testing::{Client, Server}, + testing::init_tracing, +}; +use std::{io, time::Duration}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tracing::{info_span, Instrument}; + +async fn check_stream(client: &Client, server: &Server) -> io::Result<()> { + tokio::try_join!( + async { + let mut a = client.connect_to(server).await?; + let _ = a.write_all(b"testing").await; + + // wait some time before calling shutdown in case the server reset the connection so we + // can observe it in `shutdown` + tokio::time::sleep(Duration::from_millis(10)).await; + + let _ = a.shutdown().await; + + let mut buffer = vec![]; + a.read_to_end(&mut buffer).await?; + assert_eq!(buffer, b"testing"); + Ok(()) + } + .instrument(info_span!("client")), + async { + let (mut b, _) = server.accept().await.expect("accept"); + let mut buffer = vec![]; + b.read_to_end(&mut buffer).await.unwrap(); + assert_eq!(buffer, b"testing"); + + b.write_all(&buffer).await.unwrap(); + b.shutdown().await.unwrap(); + + Ok(()) + } + .instrument(info_span!("server")) + ) + .map(|_| ()) +} + +#[tokio::test] +async fn failed_packet() { + init_tracing(); + + let client = Client::default(); + let server = Server::tcp().build(); + let mut stream = tokio::net::TcpStream::connect(server.local_addr()) + .await + .unwrap(); + // First write succeeds. + stream + .write_all(b"this is not a dcQUIC message") + .await + .unwrap(); + // Note: We do *not* shutdown the stream here, we expect the server to end the stream on its + // side since we wrote bad data. + let mut err = vec![]; + let kind = stream + .read_to_end(&mut err) + .await + .expect_err("the server should reset the connection") + .kind(); + assert_eq!(kind, io::ErrorKind::ConnectionReset); + // We currently silently drop malformed streams, ending them with an EOF. + assert_eq!(err.len(), 0); + + // Confirm subsequent streams connect successfully. + check_stream(&client, &server).await.unwrap(); +} + +#[tokio::test] +async fn immediate_eof() { + init_tracing(); + + let client = Client::default(); + let server = Server::tcp().build(); + let mut stream = tokio::net::TcpStream::connect(server.local_addr()) + .await + .unwrap(); + // Immediately end the stream without any data being sent. + stream.shutdown().await.unwrap(); + let mut err = vec![]; + let kind = stream + .read_to_end(&mut err) + .await + .expect_err("the server should reset the connection") + .kind(); + assert_eq!(kind, io::ErrorKind::ConnectionReset); + // We currently silently drop malformed streams, ending them with an EOF. + assert_eq!(err.len(), 0); + + // Confirm subsequent streams connect successfully. + check_stream(&client, &server).await.unwrap(); +} + +// Confirm that we can use all of the concurrency for streams that have not yet sent a prelude. +#[tokio::test] +async fn within_concurrency() { + init_tracing(); + + let client = Client::default(); + let concurrent = 300; + let server = Server::tcp().backlog(concurrent).build(); + + client.handshake_with(&server).unwrap(); + + let mut pending_streams = vec![]; + for _ in 0..concurrent { + let stream = tokio::net::TcpStream::connect(server.local_addr()) + .await + .unwrap(); + pending_streams.push(stream); + } + for stream in pending_streams { + // Effectively this just writes the prelude. + let mut stream = client.connect_tcp_with(&server, stream).await.unwrap(); + // Confirm stream actually opened.. + stream.write_from(&mut &[0x3u8; 100][..]).await.unwrap(); + } +} + +// Exercise dropping connections when we go over the allowed concurrency. +#[tokio::test] +async fn graceful_surpassing_concurrency() { + init_tracing(); + + let client = Client::default(); + let concurrent = 5; + let server = Server::tcp().backlog(concurrent).build(); + + client.handshake_with(&server).unwrap(); + + let mut streams = vec![]; + for _ in 0..(concurrent * 2) { + let stream = tokio::net::TcpStream::connect(server.local_addr()) + .await + .unwrap(); + streams.push(stream); + tokio::task::yield_now().await; + } + + let server_handle = server.handle(); + + tokio::task::spawn(async move { + while let Ok((mut stream, _peer_addr)) = server.accept().await { + let _ = stream.write_from(&mut &b"hello"[..]).await; + let _ = stream.shutdown().await; + drop(stream); + } + }); + + // Need to give time for server to drop the streams. + tokio::time::sleep(Duration::from_secs(1)).await; + + let mut errors = 0; + let mut ok = 0; + + for stream in streams { + let mut stream = client + .connect_tcp_with(&server_handle, stream) + .await + .unwrap(); + let mut out = s2n_quic_core::buffer::writer::storage::Discard; + let res = stream.read_into(&mut out).await; + match res { + Ok(_) => ok += 1, + Err(_e) => errors += 1, + } + } + + assert_eq!(errors + ok, concurrent * 2); + assert_eq!(errors, concurrent); + assert_eq!(ok, concurrent); +} diff --git a/dc/s2n-quic-dc/src/task/waker/set.rs b/dc/s2n-quic-dc/src/task/waker/set.rs index d6cc8f7807..caa6d91795 100644 --- a/dc/s2n-quic-dc/src/task/waker/set.rs +++ b/dc/s2n-quic-dc/src/task/waker/set.rs @@ -4,7 +4,7 @@ use super::worker; use std::{ sync::{Arc, Mutex}, - task::{Wake, Waker}, + task::{self, Wake, Waker}, }; mod bitset; @@ -14,12 +14,25 @@ use bitset::BitSet; pub struct Set { state: Arc, ready: BitSet, + local_root: Option, } impl Set { - /// Updates the root waker - pub fn update_root(&self, waker: &Waker) { - self.state.root.update(waker); + /// Called at the beginning of the `poll` function for the owner of [`Set`] + #[inline] + pub fn poll_start(&mut self, cx: &task::Context) { + let new_waker = cx.waker(); + + let root_task_requires_update = if let Some(waker) = self.local_root.as_ref() { + !waker.will_wake(new_waker) + } else { + true + }; + + if root_task_requires_update { + self.state.root.update(new_waker); + self.local_root = Some(new_waker.clone()); + } } /// Registers a waker with the given ID @@ -32,6 +45,7 @@ impl Set { } /// Returns all of the IDs that are woken + #[inline] pub fn drain(&mut self) -> impl Iterator + '_ { core::mem::swap(&mut self.ready, &mut self.state.ready.lock().unwrap()); self.ready.drain() @@ -58,7 +72,8 @@ impl Wake for Slot { ready.insert_unchecked(self.id) } drop(ready); - self.state.root.wake(); + // use `wake_forced` instead of `wake` since we don't use the sleeping status from `worker::Waker`` + self.state.root.wake_forced(); } } diff --git a/dc/s2n-quic-dc/src/task/waker/worker.rs b/dc/s2n-quic-dc/src/task/waker/worker.rs index 2568b160a0..cdb639eaaf 100644 --- a/dc/s2n-quic-dc/src/task/waker/worker.rs +++ b/dc/s2n-quic-dc/src/task/waker/worker.rs @@ -49,6 +49,11 @@ impl Waker { // we only need to `wake_by_ref` if the worker is sleeping ensure!(matches!(status, Status::Sleeping)); + self.wake_forced(); + } + + #[inline] + pub fn wake_forced(&self) { let guard = crossbeam_epoch::pin(); let waker = self.waker.load(Ordering::Acquire, &guard); let Some(waker) = (unsafe { waker.as_ref() }) else {