Skip to content

Commit

Permalink
Simplify the polling mechanism.
Browse files Browse the repository at this point in the history
This requires an updated tokio for `Interest::ERROR`.
  • Loading branch information
sunfishcode committed Aug 19, 2023
1 parent db35146 commit 8b049ef
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 238 deletions.
15 changes: 7 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ tempfile = "3.1.0"
filecheck = "0.5.0"
libc = "0.2.60"
file-per-thread-logger = "0.2.0"
tokio = { version = "1.26.0" }
tokio = { version = "1.32.0" }
bytes = "1.4"
futures = { version = "0.3.27", default-features = false }
indexmap = "2.0.0"
Expand Down
236 changes: 67 additions & 169 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,18 @@ use crate::preview2::bindings::{
use crate::preview2::network::TableNetworkExt;
use crate::preview2::poll::TablePollableExt;
use crate::preview2::stream::TableStreamExt;
use crate::preview2::tcp::{HostTcpSocket, HostTcpSocketInner, HostTcpState, TableTcpSocketExt};
use crate::preview2::tcp::{HostTcpSocket, HostTcpState, TableTcpSocketExt};
use crate::preview2::{HostPollable, PollableFuture, WasiView};
use cap_net_ext::{Blocking, PoolExt, TcpListenerExt};
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
use std::any::Any;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::RwLockWriteGuard;
#[cfg(unix)]
use tokio::task::spawn;
use tokio::io::Interest;
#[cfg(not(unix))]
use tokio::task::spawn_blocking;
use tokio::task::JoinHandle;

impl<T: WasiView> tcp::Host for T {
fn start_bind(
Expand All @@ -47,7 +43,6 @@ impl<T: WasiView> tcp::Host for T {
binder.bind_existing_tcp_listener(socket.tcp_socket())?;

set_state(tcp_state, HostTcpState::BindStarted);
socket.notify();

Ok(())
}
Expand Down Expand Up @@ -90,8 +85,7 @@ impl<T: WasiView> tcp::Host for T {
match connecter.connect_existing_tcp_listener(socket.tcp_socket()) {
// succeed immediately,
Ok(()) => {
set_state(tcp_state, HostTcpState::ConnectReady(Ok(())));
socket.notify();
set_state(tcp_state, HostTcpState::ConnectReady);
return Ok(());
}
// continue in progress,
Expand All @@ -100,54 +94,7 @@ impl<T: WasiView> tcp::Host for T {
Err(err) => return Err(err.into()),
}

// The connect is continuing in progress. Set up the join handle.

let clone = socket.clone_inner();

#[cfg(unix)]
let join = spawn(async move {
let result = match clone.tcp_socket.writable().await {
Ok(mut writable) => {
writable.retain_ready();

// Check whether the connect succeeded.
match sockopt::get_socket_error(&clone.tcp_socket) {
Ok(Ok(())) => Ok(()),
Err(err) | Ok(Err(err)) => Err(err.into()),
}
}
Err(err) => Err(err),
};

clone.set_state_and_notify(HostTcpState::ConnectReady(result));
});

#[cfg(not(unix))]
let join = spawn_blocking(move || {
let result = match rustix::event::poll(
&mut [rustix::event::PollFd::new(
&clone.tcp_socket,
rustix::event::PollFlags::OUT,
)],
-1,
) {
Ok(_) => {
// Check whether the connect succeeded.
match sockopt::get_socket_error(&clone.tcp_socket) {
Ok(Ok(())) => Ok(()),
Err(err) | Ok(Err(err)) => Err(err.into()),
}
}
Err(err) => Err(err.into()),
};

clone.set_state_and_notify(HostTcpState::ConnectReady(result));
});

set_state(
tcp_state,
HostTcpState::Connecting(Pin::from(Box::new(join))),
);
set_state(tcp_state, HostTcpState::Connecting);

Ok(())
}
Expand All @@ -161,32 +108,18 @@ impl<T: WasiView> tcp::Host for T {

let mut tcp_state = socket.tcp_state_write_lock();
match &mut *tcp_state {
HostTcpState::ConnectReady(_) => {}
HostTcpState::Connecting(join) => match maybe_unwrap_future(join) {
Some(joined) => joined.unwrap(),
None => return Err(ErrorCode::WouldBlock.into()),
},
HostTcpState::ConnectReady => {}
HostTcpState::Connecting => {
// Check whether the connect succeeded.
match sockopt::get_socket_error(socket.tcp_socket()) {
Ok(Ok(())) => {}
Err(err) | Ok(Err(err)) => return Err(err.into()),
}
}
_ => return Err(ErrorCode::NotInProgress.into()),
};

let old_state = mem::replace(&mut *tcp_state, HostTcpState::Connected);

// Extract the connection result.
let result = match old_state {
HostTcpState::ConnectReady(result) => result,
_ => unreachable!(),
};

// Report errors, resetting the state if needed.
match result {
Ok(()) => {}
Err(err) => {
set_state(tcp_state, HostTcpState::Default);
return Err(err.into());
}
}

drop(tcp_state);
set_state(tcp_state, HostTcpState::Connected);

let input_clone = socket.clone_inner();
let output_clone = socket.clone_inner();
Expand Down Expand Up @@ -214,7 +147,6 @@ impl<T: WasiView> tcp::Host for T {
socket.tcp_socket().listen(None)?;

set_state(tcp_state, HostTcpState::ListenStarted);
socket.notify();

Ok(())
}
Expand All @@ -230,11 +162,7 @@ impl<T: WasiView> tcp::Host for T {
_ => return Err(ErrorCode::NotInProgress.into()),
}

let new_join = spawn_task_to_wait_for_connections(socket.clone_inner());
set_state(
tcp_state,
HostTcpState::Listening(Pin::from(Box::new(new_join))),
);
set_state(tcp_state, HostTcpState::Listening);

Ok(())
}
Expand All @@ -248,20 +176,12 @@ impl<T: WasiView> tcp::Host for T {

let mut tcp_state = socket.tcp_state_write_lock();
match &mut *tcp_state {
HostTcpState::ListenReady(_) => {}
HostTcpState::Listening(join) => match maybe_unwrap_future(join) {
Some(joined) => joined.unwrap(),
None => return Err(ErrorCode::WouldBlock.into()),
},
HostTcpState::Listening => {}
HostTcpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
_ => return Err(ErrorCode::NotInProgress.into()),
}

let new_join = spawn_task_to_wait_for_connections(socket.clone_inner());
set_state(
tcp_state,
HostTcpState::Listening(Pin::from(Box::new(new_join))),
);
set_state(tcp_state, HostTcpState::Listening);

// Do the OS accept call.
let (connection, _addr) = socket.tcp_socket().accept_with(Blocking::No)?;
Expand Down Expand Up @@ -366,7 +286,7 @@ impl<T: WasiView> tcp::Host for T {

let tcp_state = socket.tcp_state_read_lock();
match &*tcp_state {
HostTcpState::Listening(_) => {}
HostTcpState::Listening => {}
_ => return Err(ErrorCode::NotInProgress.into()),
}

Expand Down Expand Up @@ -478,10 +398,55 @@ impl<T: WasiView> tcp::Host for T {
.downcast_mut::<HostTcpSocket>()
.expect("downcast to HostTcpSocket failed");

Box::pin(async {
socket.receiver.changed().await.unwrap();
// Some states are ready immediately.
match *socket.tcp_state_read_lock() {
HostTcpState::BindStarted
| HostTcpState::ListenStarted
| HostTcpState::ConnectReady => return Box::pin(async { Ok(()) }),
_ => {}
}

#[cfg(unix)]
let join = Box::pin(async move {
socket
.inner
.tcp_socket
.ready(Interest::READABLE | Interest::WRITABLE | Interest::ERROR)
.await
.unwrap()
.retain_ready();
Ok(())
});

#[cfg(not(unix))]
let join = Box::pin(async move {
let clone = socket.clone_inner();
spawn_blocking(move || loop {
#[cfg(not(windows))]
let poll_flags = rustix::event::PollFlags::IN
| rustix::event::PollFlags::OUT
| rustix::event::PollFlags::ERR
| rustix::event::PollFlags::HUP;
// Windows doesn't appear to support `HUP`, or `ERR`
// combined with `IN`/`OUT`.
#[cfg(windows)]
let poll_flags = rustix::event::PollFlags::IN | rustix::event::PollFlags::OUT;
match rustix::event::poll(
&mut [rustix::event::PollFd::new(&clone.tcp_socket, poll_flags)],
-1,
) {
Ok(_) => break,
Err(Errno::INTR) => (),
Err(err) => Err(err).unwrap(),
}
})
.await
.unwrap();

Ok(())
})
});

join
}

let pollable = HostPollable::TableEntry {
Expand Down Expand Up @@ -529,12 +494,9 @@ impl<T: WasiView> tcp::Host for T {
| HostTcpState::BindStarted
| HostTcpState::Bound
| HostTcpState::ListenStarted
| HostTcpState::ListenReady(_)
| HostTcpState::ConnectReady(_) => {}
| HostTcpState::ConnectReady => {}

HostTcpState::Listening(_)
| HostTcpState::Connecting(_)
| HostTcpState::Connected => {
HostTcpState::Listening | HostTcpState::Connecting | HostTcpState::Connected => {
match rustix::net::shutdown(
&dropped.inner.tcp_socket,
rustix::net::Shutdown::ReadWrite,
Expand All @@ -552,76 +514,12 @@ impl<T: WasiView> tcp::Host for T {
}
}

/// Spawn a task to monitor a socket for incoming connections that
/// can be `accept`ed.
fn spawn_task_to_wait_for_connections(socket: Arc<HostTcpSocketInner>) -> JoinHandle<()> {
#[cfg(unix)]
let join = spawn(async move {
socket.tcp_socket.readable().await.unwrap().retain_ready();
socket.set_state_and_notify(HostTcpState::ListenReady(Ok(())));
});

#[cfg(not(unix))]
let join = spawn_blocking(move || {
let result = match rustix::event::poll(
&mut [rustix::event::PollFd::new(
&socket.tcp_socket,
rustix::event::PollFlags::IN,
)],
-1,
) {
Ok(_) => Ok(()),
Err(err) => Err(err.into()),
};
socket.set_state_and_notify(HostTcpState::ListenReady(result));
});

join
}

/// Set `*tcp_state` to `new_state` and consume `tcp_state`.
fn set_state(tcp_state: RwLockWriteGuard<HostTcpState>, new_state: HostTcpState) {
let mut tcp_state = tcp_state;
*tcp_state = new_state;
}

/// Given a future, return the finished value if it's already ready, or
/// `None` if it's not.
fn maybe_unwrap_future<F: std::future::Future + std::marker::Unpin>(
future: &mut Pin<Box<F>>,
) -> Option<F::Output> {
use std::ptr;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Create a no-op Waker. This is derived from [code in std] and can
// be replaced with `std::task::Waker::noop()` when the "noop_waker"
// feature is stablized.
//
// [code in std]: https://github.com/rust-lang/rust/blob/27fb598d51d4566a725e4868eaf5d2e15775193e/library/core/src/task/wake.rs#L349
fn noop_waker() -> Waker {
const VTABLE: RawWakerVTable = RawWakerVTable::new(
// Cloning just returns a new no-op raw waker
|_| RAW,
// `wake` does nothing
|_| {},
// `wake_by_ref` does nothing
|_| {},
// Dropping does nothing as we don't allocate anything
|_| {},
);
const RAW: RawWaker = RawWaker::new(ptr::null(), &VTABLE);

unsafe { Waker::from_raw(RAW) }
}

let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
match future.as_mut().poll(&mut cx) {
Poll::Ready(val) => Some(val),
Poll::Pending => None,
}
}

// On POSIX, non-blocking TCP socket `connect` uses `EINPROGRESS`.
// <https://pubs.opengroup.org/onlinepubs/9699919799/functions/connect.html>
#[cfg(not(windows))]
Expand Down
Loading

0 comments on commit 8b049ef

Please sign in to comment.