diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index 36ef265..4d888ad 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -64,42 +64,48 @@ impl TcpListener { /// established, the corresponding [`TcpStream`] and the remote peer’s /// address will be returned. pub async fn accept(&self) -> Result<(TcpStream, SocketAddr)> { - loop { + let origin = loop { let maybe_accept = World::current(|world| { let host = world.current_host_mut(); - let (syn, origin) = host.tcp.accept(self.local_addr)?; - - tracing::trace!(target: TRACING_TARGET, src = ?origin, dst = ?self.local_addr, protocol = %"TCP SYN", "Recv"); + host.tcp.accept(self.local_addr) + }); - // Send SYN-ACK -> origin. If Ok we proceed (acts as the ACK), - // else we return early to avoid host mutations. - let ack = syn.ack.send(()); - tracing::trace!(target: TRACING_TARGET, src = ?self.local_addr, dst = ?origin, protocol = %"TCP SYN-ACK", "Send"); + let Some((syn, origin)) = maybe_accept else { + // Wait for a new incoming connection, then retry. + self.notify.notified().await; + continue; + }; - if ack.is_err() { - return None; - } + tracing::trace!(target: TRACING_TARGET, src = ?origin, dst = ?self.local_addr, protocol = %"TCP SYN", "Recv"); - let mut my_addr = self.local_addr; - if origin.ip().is_loopback() { - my_addr.set_ip(origin.ip()); - } - if my_addr.ip().is_unspecified() { - my_addr.set_ip(host.addr); - } + // Send SYN-ACK -> origin. If Ok we proceed (acts as the ACK), else + // we retry. + let ack = syn.ack.send(()); + tracing::trace!(target: TRACING_TARGET, src = ?self.local_addr, dst = ?origin, protocol = %"TCP SYN-ACK", "Send"); - let pair = SocketPair::new(my_addr, origin); - let rx = host.tcp.new_stream(pair); + if ack.is_ok() { + break origin; + } + }; - Some((TcpStream::new(pair, rx), origin)) - }); + let stream = World::current(|world| { + let host = world.current_host_mut(); - if let Some(accepted) = maybe_accept { - return Ok(accepted); + let mut my_addr = self.local_addr; + if origin.ip().is_loopback() { + my_addr.set_ip(origin.ip()); + } + if my_addr.ip().is_unspecified() { + my_addr.set_ip(host.addr); } - self.notify.notified().await; - } + let pair = SocketPair::new(my_addr, origin); + let rx = host.tcp.new_stream(pair); + TcpStream::new(pair, rx) + }); + + tracing::trace!(target: TRACING_TARGET, src = ?self.local_addr, dst = ?origin, "Accepted"); + Ok((stream, origin)) } /// Returns the local address that this listener is bound to. diff --git a/tests/tcp.rs b/tests/tcp.rs index 8ce5aa6..81fc1de 100644 --- a/tests/tcp.rs +++ b/tests/tcp.rs @@ -292,6 +292,48 @@ fn accept_front_of_line_blocking() -> Result { sim.run() } +#[test] +fn accept_front_of_line_dropped() -> Result { + let wait = Rc::new(Notify::new()); + let notify = wait.clone(); + + let mut sim = Builder::new() + .min_message_latency(Duration::ZERO) + .max_message_latency(Duration::ZERO) + .build(); + + sim.host("server", move || { + let wait = Rc::clone(&wait); + async move { + let listener = bind().await?; + wait.notified().await; + + while let Ok((_, peer)) = listener.accept().await { + tracing::debug!("peer {}", peer); + } + + Ok(()) + } + }); + + sim.client("client", async move { + // Queue up a number of broken connections at the server. + for _ in 0..5 { + let connect = TcpStream::connect(("server", PORT)); + assert!(timeout(Duration::from_secs(1), connect).await.is_err()); + } + + // After allowing the server to accept, the next connection attempt + // should succeed. + notify.notify_one(); + let _ = TcpStream::connect(("server", PORT)).await?; + + Ok(()) + }); + + sim.run() +} + #[test] fn send_upon_accept() -> Result { let mut sim = Builder::new().build();