diff --git a/Cargo.lock b/Cargo.lock index 0cbdfa29feb9..0d56d9a83b08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1969,9 +1969,9 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "2c516611246607d0c04186886dbb3a754368ef82c79e9827a802c6d836dd111c" [[package]] name = "pin-utils" @@ -2506,12 +2506,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.4.9" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" dependencies = [ "libc", - "winapi", + "windows-sys", ] [[package]] @@ -2747,11 +2747,10 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.29.1" +version = "1.32.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" +checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" dependencies = [ - "autocfg", "backtrace", "bytes", "libc", diff --git a/Cargo.toml b/Cargo.toml index 5ccd378d7af3..cb62cc293516 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -244,7 +244,7 @@ tempfile = "3.1.0" filecheck = "0.5.0" libc = "0.2.60" file-per-thread-logger = "0.2.0" -tokio = { version = "1.26.0" } +tokio = { version = "1.32.0" } bytes = "1.4" futures = { version = "0.3.27", default-features = false } indexmap = "2.0.0" diff --git a/crates/wasi/src/preview2/host/tcp.rs b/crates/wasi/src/preview2/host/tcp.rs index 9b252b236df6..24fd6305eb04 100644 --- a/crates/wasi/src/preview2/host/tcp.rs +++ b/crates/wasi/src/preview2/host/tcp.rs @@ -7,22 +7,18 @@ use crate::preview2::bindings::{ use crate::preview2::network::TableNetworkExt; use crate::preview2::poll::TablePollableExt; use crate::preview2::stream::TableStreamExt; -use crate::preview2::tcp::{HostTcpSocket, HostTcpSocketInner, HostTcpState, TableTcpSocketExt}; +use crate::preview2::tcp::{HostTcpSocket, HostTcpState, TableTcpSocketExt}; use crate::preview2::{HostPollable, PollableFuture, WasiView}; use cap_net_ext::{Blocking, PoolExt, TcpListenerExt}; use io_lifetimes::AsSocketlike; use rustix::io::Errno; use rustix::net::sockopt; use std::any::Any; -use std::mem; -use std::pin::Pin; -use std::sync::Arc; use std::sync::RwLockWriteGuard; #[cfg(unix)] -use tokio::task::spawn; +use tokio::io::Interest; #[cfg(not(unix))] use tokio::task::spawn_blocking; -use tokio::task::JoinHandle; impl tcp::Host for T { fn start_bind( @@ -47,7 +43,6 @@ impl tcp::Host for T { binder.bind_existing_tcp_listener(socket.tcp_socket())?; set_state(tcp_state, HostTcpState::BindStarted); - socket.notify(); Ok(()) } @@ -90,8 +85,7 @@ impl tcp::Host for T { match connecter.connect_existing_tcp_listener(socket.tcp_socket()) { // succeed immediately, Ok(()) => { - set_state(tcp_state, HostTcpState::ConnectReady(Ok(()))); - socket.notify(); + set_state(tcp_state, HostTcpState::ConnectReady); return Ok(()); } // continue in progress, @@ -100,54 +94,7 @@ impl tcp::Host for T { Err(err) => return Err(err.into()), } - // The connect is continuing in progress. Set up the join handle. - - let clone = socket.clone_inner(); - - #[cfg(unix)] - let join = spawn(async move { - let result = match clone.tcp_socket.writable().await { - Ok(mut writable) => { - writable.retain_ready(); - - // Check whether the connect succeeded. - match sockopt::get_socket_error(&clone.tcp_socket) { - Ok(Ok(())) => Ok(()), - Err(err) | Ok(Err(err)) => Err(err.into()), - } - } - Err(err) => Err(err), - }; - - clone.set_state_and_notify(HostTcpState::ConnectReady(result)); - }); - - #[cfg(not(unix))] - let join = spawn_blocking(move || { - let result = match rustix::event::poll( - &mut [rustix::event::PollFd::new( - &clone.tcp_socket, - rustix::event::PollFlags::OUT, - )], - -1, - ) { - Ok(_) => { - // Check whether the connect succeeded. - match sockopt::get_socket_error(&clone.tcp_socket) { - Ok(Ok(())) => Ok(()), - Err(err) | Ok(Err(err)) => Err(err.into()), - } - } - Err(err) => Err(err.into()), - }; - - clone.set_state_and_notify(HostTcpState::ConnectReady(result)); - }); - - set_state( - tcp_state, - HostTcpState::Connecting(Pin::from(Box::new(join))), - ); + set_state(tcp_state, HostTcpState::Connecting); Ok(()) } @@ -161,32 +108,18 @@ impl tcp::Host for T { let mut tcp_state = socket.tcp_state_write_lock(); match &mut *tcp_state { - HostTcpState::ConnectReady(_) => {} - HostTcpState::Connecting(join) => match maybe_unwrap_future(join) { - Some(joined) => joined.unwrap(), - None => return Err(ErrorCode::WouldBlock.into()), - }, + HostTcpState::ConnectReady => {} + HostTcpState::Connecting => { + // Check whether the connect succeeded. + match sockopt::get_socket_error(socket.tcp_socket()) { + Ok(Ok(())) => {} + Err(err) | Ok(Err(err)) => return Err(err.into()), + } + } _ => return Err(ErrorCode::NotInProgress.into()), }; - let old_state = mem::replace(&mut *tcp_state, HostTcpState::Connected); - - // Extract the connection result. - let result = match old_state { - HostTcpState::ConnectReady(result) => result, - _ => unreachable!(), - }; - - // Report errors, resetting the state if needed. - match result { - Ok(()) => {} - Err(err) => { - set_state(tcp_state, HostTcpState::Default); - return Err(err.into()); - } - } - - drop(tcp_state); + set_state(tcp_state, HostTcpState::Connected); let input_clone = socket.clone_inner(); let output_clone = socket.clone_inner(); @@ -214,7 +147,6 @@ impl tcp::Host for T { socket.tcp_socket().listen(None)?; set_state(tcp_state, HostTcpState::ListenStarted); - socket.notify(); Ok(()) } @@ -230,11 +162,7 @@ impl tcp::Host for T { _ => return Err(ErrorCode::NotInProgress.into()), } - let new_join = spawn_task_to_wait_for_connections(socket.clone_inner()); - set_state( - tcp_state, - HostTcpState::Listening(Pin::from(Box::new(new_join))), - ); + set_state(tcp_state, HostTcpState::Listening); Ok(()) } @@ -248,20 +176,12 @@ impl tcp::Host for T { let mut tcp_state = socket.tcp_state_write_lock(); match &mut *tcp_state { - HostTcpState::ListenReady(_) => {} - HostTcpState::Listening(join) => match maybe_unwrap_future(join) { - Some(joined) => joined.unwrap(), - None => return Err(ErrorCode::WouldBlock.into()), - }, + HostTcpState::Listening => {} HostTcpState::Connected => return Err(ErrorCode::AlreadyConnected.into()), _ => return Err(ErrorCode::NotInProgress.into()), } - let new_join = spawn_task_to_wait_for_connections(socket.clone_inner()); - set_state( - tcp_state, - HostTcpState::Listening(Pin::from(Box::new(new_join))), - ); + set_state(tcp_state, HostTcpState::Listening); // Do the OS accept call. let (connection, _addr) = socket.tcp_socket().accept_with(Blocking::No)?; @@ -366,7 +286,7 @@ impl tcp::Host for T { let tcp_state = socket.tcp_state_read_lock(); match &*tcp_state { - HostTcpState::Listening(_) => {} + HostTcpState::Listening => {} _ => return Err(ErrorCode::NotInProgress.into()), } @@ -478,10 +398,55 @@ impl tcp::Host for T { .downcast_mut::() .expect("downcast to HostTcpSocket failed"); - Box::pin(async { - socket.receiver.changed().await.unwrap(); + // Some states are ready immediately. + match *socket.tcp_state_read_lock() { + HostTcpState::BindStarted + | HostTcpState::ListenStarted + | HostTcpState::ConnectReady => return Box::pin(async { Ok(()) }), + _ => {} + } + + #[cfg(unix)] + let join = Box::pin(async move { + socket + .inner + .tcp_socket + .ready(Interest::READABLE | Interest::WRITABLE | Interest::ERROR) + .await + .unwrap() + .retain_ready(); + Ok(()) + }); + + #[cfg(not(unix))] + let join = Box::pin(async move { + let clone = socket.clone_inner(); + spawn_blocking(move || loop { + #[cfg(not(windows))] + let poll_flags = rustix::event::PollFlags::IN + | rustix::event::PollFlags::OUT + | rustix::event::PollFlags::ERR + | rustix::event::PollFlags::HUP; + // Windows doesn't appear to support `HUP`, or `ERR` + // combined with `IN`/`OUT`. + #[cfg(windows)] + let poll_flags = rustix::event::PollFlags::IN | rustix::event::PollFlags::OUT; + match rustix::event::poll( + &mut [rustix::event::PollFd::new(&clone.tcp_socket, poll_flags)], + -1, + ) { + Ok(_) => break, + Err(Errno::INTR) => (), + Err(err) => Err(err).unwrap(), + } + }) + .await + .unwrap(); + Ok(()) - }) + }); + + join } let pollable = HostPollable::TableEntry { @@ -529,12 +494,9 @@ impl tcp::Host for T { | HostTcpState::BindStarted | HostTcpState::Bound | HostTcpState::ListenStarted - | HostTcpState::ListenReady(_) - | HostTcpState::ConnectReady(_) => {} + | HostTcpState::ConnectReady => {} - HostTcpState::Listening(_) - | HostTcpState::Connecting(_) - | HostTcpState::Connected => { + HostTcpState::Listening | HostTcpState::Connecting | HostTcpState::Connected => { match rustix::net::shutdown( &dropped.inner.tcp_socket, rustix::net::Shutdown::ReadWrite, @@ -552,76 +514,12 @@ impl tcp::Host for T { } } -/// Spawn a task to monitor a socket for incoming connections that -/// can be `accept`ed. -fn spawn_task_to_wait_for_connections(socket: Arc) -> JoinHandle<()> { - #[cfg(unix)] - let join = spawn(async move { - socket.tcp_socket.readable().await.unwrap().retain_ready(); - socket.set_state_and_notify(HostTcpState::ListenReady(Ok(()))); - }); - - #[cfg(not(unix))] - let join = spawn_blocking(move || { - let result = match rustix::event::poll( - &mut [rustix::event::PollFd::new( - &socket.tcp_socket, - rustix::event::PollFlags::IN, - )], - -1, - ) { - Ok(_) => Ok(()), - Err(err) => Err(err.into()), - }; - socket.set_state_and_notify(HostTcpState::ListenReady(result)); - }); - - join -} - /// Set `*tcp_state` to `new_state` and consume `tcp_state`. fn set_state(tcp_state: RwLockWriteGuard, new_state: HostTcpState) { let mut tcp_state = tcp_state; *tcp_state = new_state; } -/// Given a future, return the finished value if it's already ready, or -/// `None` if it's not. -fn maybe_unwrap_future( - future: &mut Pin>, -) -> Option { - use std::ptr; - use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; - - // Create a no-op Waker. This is derived from [code in std] and can - // be replaced with `std::task::Waker::noop()` when the "noop_waker" - // feature is stablized. - // - // [code in std]: https://github.com/rust-lang/rust/blob/27fb598d51d4566a725e4868eaf5d2e15775193e/library/core/src/task/wake.rs#L349 - fn noop_waker() -> Waker { - const VTABLE: RawWakerVTable = RawWakerVTable::new( - // Cloning just returns a new no-op raw waker - |_| RAW, - // `wake` does nothing - |_| {}, - // `wake_by_ref` does nothing - |_| {}, - // Dropping does nothing as we don't allocate anything - |_| {}, - ); - const RAW: RawWaker = RawWaker::new(ptr::null(), &VTABLE); - - unsafe { Waker::from_raw(RAW) } - } - - let waker = noop_waker(); - let mut cx = Context::from_waker(&waker); - match future.as_mut().poll(&mut cx) { - Poll::Ready(val) => Some(val), - Poll::Pending => None, - } -} - // On POSIX, non-blocking TCP socket `connect` uses `EINPROGRESS`. // #[cfg(not(windows))] diff --git a/crates/wasi/src/preview2/tcp.rs b/crates/wasi/src/preview2/tcp.rs index f4cfc80ce479..adc815f1842e 100644 --- a/crates/wasi/src/preview2/tcp.rs +++ b/crates/wasi/src/preview2/tcp.rs @@ -4,11 +4,8 @@ use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt}; use cap_std::net::{TcpListener, TcpStream}; use io_lifetimes::AsSocketlike; use std::io; -use std::pin::Pin; use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; use system_interface::io::IoExt; -use tokio::sync::watch::{channel, Receiver, Sender}; -use tokio::task::JoinHandle; /// The state of a TCP socket. /// @@ -29,17 +26,13 @@ pub(crate) enum HostTcpState { ListenStarted, /// The socket is now listening and waiting for an incoming connection. - Listening(Pin>>), - - /// Listening heard an incoming connection arrive that is ready to be - /// accepted. - ListenReady(io::Result<()>), + Listening, /// An outgoing connection is started via `start_connect`. - Connecting(Pin>>), + Connecting, /// An outgoing connection is ready to be established. - ConnectReady(io::Result<()>), + ConnectReady, /// An outgoing connection has been established. Connected, @@ -55,10 +48,6 @@ pub(crate) struct HostTcpSocket { /// The part of a `HostTcpSocket` which is reference-counted so that we /// can pass it to async tasks. pub(crate) inner: Arc, - - /// The recieving end of `inner`'s `sender`, used by `subscribe` - /// subscriptions to wait for I/O. - pub(crate) receiver: Receiver<()>, } /// The inner reference-counted state of a `HostTcpSocket`. @@ -73,9 +62,6 @@ pub(crate) struct HostTcpSocketInner { /// The current state in the bind/listen/accept/connect progression. pub(crate) tcp_state: RwLock, - - /// A sender used to send messages when I/O events complete. - pub(crate) sender: Sender<()>, } impl HostTcpSocket { @@ -89,15 +75,11 @@ impl HostTcpSocket { #[cfg(unix)] let tcp_socket = tokio::io::unix::AsyncFd::new(tcp_socket)?; - let (sender, receiver) = channel(()); - Ok(Self { inner: Arc::new(HostTcpSocketInner { tcp_socket, tcp_state: RwLock::new(HostTcpState::Default), - sender, }), - receiver, }) } @@ -112,15 +94,11 @@ impl HostTcpSocket { #[cfg(unix)] let tcp_socket = tokio::io::unix::AsyncFd::new(tcp_socket)?; - let (sender, receiver) = channel(()); - Ok(Self { inner: Arc::new(HostTcpSocketInner { tcp_socket, tcp_state: RwLock::new(HostTcpState::Default), - sender, }), - receiver, }) } @@ -128,10 +106,6 @@ impl HostTcpSocket { self.inner.tcp_socket() } - pub fn notify(&self) { - self.inner.notify() - } - pub fn clone_inner(&self) -> Arc { Arc::clone(&self.inner) } @@ -158,19 +132,10 @@ impl HostTcpSocketInner { tcp_socket } - pub fn notify(&self) { - self.sender.send(()).unwrap() - } - pub fn set_state(&self, new_state: HostTcpState) { *self.tcp_state.write().unwrap() = new_state; } - pub fn set_state_and_notify(&self, new_state: HostTcpState) { - self.set_state(new_state); - self.notify() - } - /// Spawn a task on tokio's blocking thread for performing blocking /// syscalls on the underlying [`cap_std::net::TcpListener`]. #[cfg(not(unix))] @@ -215,7 +180,9 @@ impl HostInputStream for Arc { match rustix::event::poll( &mut [rustix::event::PollFd::new( tcp_socket, - rustix::event::PollFlags::IN, + rustix::event::PollFlags::IN + | rustix::event::PollFlags::ERR + | rustix::event::PollFlags::HUP, )], -1, ) { @@ -255,7 +222,9 @@ impl HostOutputStream for Arc { match rustix::event::poll( &mut [rustix::event::PollFd::new( tcp_socket, - rustix::event::PollFlags::OUT, + rustix::event::PollFlags::OUT + | rustix::event::PollFlags::ERR + | rustix::event::PollFlags::HUP, )], -1, ) { @@ -268,24 +237,6 @@ impl HostOutputStream for Arc { } } -impl Drop for HostTcpSocketInner { - fn drop(&mut self) { - match &*self.tcp_state.read().unwrap() { - HostTcpState::Default - | HostTcpState::BindStarted - | HostTcpState::Bound - | HostTcpState::ListenStarted - | HostTcpState::ListenReady(_) - | HostTcpState::ConnectReady(_) - | HostTcpState::Connected => {} - HostTcpState::Listening(join) | HostTcpState::Connecting(join) => { - // Abort the tasks so that they don't detach. - join.abort(); - } - } - } -} - pub(crate) trait TableTcpSocketExt { fn push_tcp_socket(&mut self, tcp_socket: HostTcpSocket) -> Result; fn delete_tcp_socket(&mut self, fd: u32) -> Result; diff --git a/supply-chain/config.toml b/supply-chain/config.toml index 683f4da23a82..b3ee3c5da423 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -549,7 +549,7 @@ version = "1.8.0" criteria = "safe-to-deploy" [[exemptions.socket2]] -version = "0.4.4" +version = "0.5.3" criteria = "safe-to-deploy" [[exemptions.souper-ir]] @@ -589,7 +589,7 @@ version = "1.2.1" criteria = "safe-to-run" [[exemptions.tokio]] -version = "1.29.1" +version = "1.32.0" criteria = "safe-to-deploy" notes = "we are exempting tokio, hyper, and their tightly coupled dependencies by the same authors, expecting that the authors at aws will publish attestions we can import at some point soon"