add a join_background_tasks method to Host{In,Out}putStream and WasiCtx.
This method flushes output and terminates background tasks. Background
tasks still terminate as part of Drop!

The problem with the current implementation is that there is no way to
wait for output buffered in host background tasks to flush before
aborting those tasks as part of dropping the Store/Table. For example,
if a trivial component prints "hello world\n" to stdout and returns, and
the Store and Table are dropped immediately after the component finishes
executing, the drop races with the background task and the output may
never be written.

I don't really love this design, but we got backed into a corner because
all of the alternatives we could think up are worse:

* We can't just get rid of the abort-on-drop completely ("daemonize" the
tasks) because that means that streams that are connected to e.g. a
stalled client connection will consume resources forever, which is not
acceptable in some embeddings.
* We can't ensure flushing on drop of a table/store because it requires
an await, and Rust does not have an async drop.
* We can't add an explicit destructor to a table/store that terminates
tasks, and let tasks "daemonize" when that destructor is not called:
cancelling the future that executes a component before the explicit
destructor runs would then daemonize the tasks.
* We could configure all AsyncWriteStreams (and any other stream impls
that spawn a task) at creation, or at insertion to the table, with
whether they should daemonize on drop or not. This would mean plumbing a
bunch of config into places it currently is not.

Therefore, the only behavior we could come up with was to keep
the abort-on-drop behavior for background tasks, and add methods
to ensure that background tasks are joined (finished) gracefully.
This means that both sync and async users of WasiCtx will need to
call the appropriate method to wait on background tasks. This is
easy enough for users to miss, but we decided that the alternatives are
worse.

Closes bytecodealliance#6811
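
The trade-off in the commit message can be sketched without tokio. The following is a minimal analogue, not the code from this commit: `BackgroundWriter`, its `cancel` flag, and the use of an OS thread in place of a `tokio::task` are all assumptions made for illustration. Dropping the handle takes the abort path and may lose queued output, while calling `join` first drains the queue.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Sender};
use std::sync::Arc;
use std::thread::{self, JoinHandle};

/// Hypothetical stand-in for a stream backed by a background worker.
pub struct BackgroundWriter {
    sender: Option<Sender<Vec<u8>>>,
    cancel: Arc<AtomicBool>,
    // `Some` until `join` is called, like the `Option<JoinHandle>` in this commit.
    worker: Option<JoinHandle<Vec<u8>>>,
}

impl BackgroundWriter {
    pub fn new() -> Self {
        let (sender, receiver) = mpsc::channel::<Vec<u8>>();
        let cancel = Arc::new(AtomicBool::new(false));
        let cancelled = Arc::clone(&cancel);
        let worker = thread::spawn(move || {
            let mut sink = Vec::new();
            // The loop ends once the sender is dropped and the queue is drained,
            // or as soon as cancellation is observed.
            for chunk in receiver {
                if cancelled.load(Ordering::SeqCst) {
                    break;
                }
                sink.extend_from_slice(&chunk);
            }
            sink
        });
        Self { sender: Some(sender), cancel, worker: Some(worker) }
    }

    /// Queue bytes for the worker; returns immediately.
    pub fn write(&mut self, bytes: &[u8]) {
        let sender = self.sender.as_ref().expect("write after join");
        sender.send(bytes.to_vec()).expect("worker exited early");
    }

    /// Graceful path: close the queue, wait for the worker to drain it,
    /// and return everything it wrote.
    pub fn join(&mut self) -> Vec<u8> {
        drop(self.sender.take());
        match self.worker.take() {
            Some(handle) => handle.join().expect("worker panicked"),
            None => Vec::new(),
        }
    }
}

impl Drop for BackgroundWriter {
    fn drop(&mut self) {
        // Abort path: tell the worker to stop without draining the queue.
        // Chunks it has not reached yet are lost.
        self.cancel.store(true, Ordering::SeqCst);
    }
}
```

A caller that wants its output must call `join` before the handle goes out of scope, which is the same obligation this commit places on users of `WasiCtx`.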
Pat Hickey authored and alexcrichton committed Aug 10, 2023
1 parent 137c6f6 commit 5aa24c5
Showing 5 changed files with 209 additions and 30 deletions.
108 changes: 107 additions & 1 deletion crates/wasi/src/preview2/ctx.rs
@@ -2,11 +2,14 @@ use super::clocks::host::{monotonic_clock, wall_clock};
use crate::preview2::{
clocks::{self, HostMonotonicClock, HostWallClock},
filesystem::{Dir, TableFsExt},
pipe, random, stdio,
pipe,
poll::TablePollableExt,
random, stdio,
stream::{HostInputStream, HostOutputStream, TableStreamExt},
DirPerms, FilePerms, Table,
};
use cap_rand::{Rng, RngCore, SeedableRng};
use std::future::Future;
use std::mem;

pub struct WasiCtxBuilder {
@@ -258,3 +261,106 @@ pub struct WasiCtx {
pub(crate) stdout: u32,
pub(crate) stderr: u32,
}

impl WasiCtx {
/// Wait for all background tasks to join (complete) gracefully, after flushing any
/// buffered output.
///
/// NOTE: This function should be used when [`WasiCtx`] is used in an async embedding
/// (i.e. with [`crate::preview2::command::add_to_linker`]). Use its counterpart
/// `sync_join_background_tasks` in a synchronous embedding (i.e. with
/// [`crate::preview2::command::sync::add_to_linker`]).
///
/// In order to implement non-blocking streams, we often need to offload async
/// operations to background `tokio::task`s. These tasks are aborted when the resources
/// in the `Table` referencing them are dropped. In some cases, this abort may occur before
/// buffered output has been flushed. Use this function to wait for all background tasks to
/// join gracefully.
///
/// In some embeddings, a misbehaving client might cause this graceful exit to wait for an
/// unbounded amount of time, so we recommend bounding this with a timeout or other mechanism.
pub fn join_background_tasks<'a>(&mut self, table: &'a mut Table) -> impl Future<Output = ()> {
use std::pin::Pin;
use std::task::{Context, Poll};
let keys = table.keys().cloned().collect::<Vec<u32>>();
let mut set = Vec::new();
// we can't remove a stream from the table if it has any child pollables,
// so first delete all pollables from the table.
for k in keys.iter() {
let _ = table.delete_host_pollable(*k);
}
for k in keys {
match table.delete_output_stream(k) {
Ok(mut ostream) => {
// async block takes ownership of the ostream and flushes it
let f = async move { ostream.join_background_tasks().await };
set.push(Box::pin(f) as _)
}
_ => {}
}
match table.delete_input_stream(k) {
Ok(mut istream) => {
// async block takes ownership of the istream and joins its background task
let f = async move { istream.join_background_tasks().await };
set.push(Box::pin(f) as _)
}
_ => {}
}
}
// poll futures until all are ready.
// We can't write this as an `async fn` because we want to eagerly poll on each possible
// join, rather than sequentially awaiting on them.
struct JoinAll(Vec<Pin<Box<dyn Future<Output = ()> + Send>>>);
impl Future for JoinAll {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
// Iterate through the set, polling each future and removing it from the set if it
// is ready:
self.as_mut()
.0
.retain_mut(|fut| match fut.as_mut().poll(cx) {
Poll::Ready(_) => false,
_ => true,
});
// Ready if set is empty:
if self.as_mut().0.is_empty() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}

JoinAll(set)
}

/// Wait for all background tasks to join (complete) gracefully, after flushing any
/// buffered output.
///
/// NOTE: This function should be used when [`WasiCtx`] is used in a synchronous embedding
/// (i.e. with [`crate::preview2::command::sync::add_to_linker`]). Use its counterpart
/// `join_background_tasks` in an async embedding (i.e. with
/// [`crate::preview2::command::add_to_linker`]).
///
/// In order to implement non-blocking streams, we often need to offload async
/// operations to background `tokio::task`s. These tasks are aborted when the resources
/// in the `Table` referencing them are dropped. In some cases, this abort may occur before
/// buffered output has been flushed. Use this function to wait for all background tasks to
/// join gracefully.
///
/// In some embeddings, a misbehaving client might cause this graceful exit to wait for an
/// unbounded amount of time, so we recommend providing a timeout for this method.
pub fn sync_join_background_tasks<'a>(
&mut self,
table: &'a mut Table,
timeout: Option<std::time::Duration>,
) {
crate::preview2::in_tokio(async move {
if let Some(timeout) = timeout {
let _ = tokio::time::timeout(timeout, self.join_background_tasks(table)).await;
} else {
self.join_background_tasks(table).await
}
})
}
}
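
The `JoinAll` future in the hunk above polls every pending join on each wakeup instead of awaiting them one after another. The sketch below isolates that pattern using only the standard library. `CountDown`, `count_down`, `NoopWaker`, and `polls_until_ready` are hypothetical helpers added for illustration; the real code is driven by the tokio runtime rather than a busy loop.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

pub type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;

/// Completes once every future in the set has completed.
pub struct JoinAll(pub Vec<BoxFuture>);

impl Future for JoinAll {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Poll every remaining future; keep only those still pending.
        self.0.retain_mut(|fut| fut.as_mut().poll(cx).is_pending());
        if self.0.is_empty() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

/// Hypothetical test future: pending for `remaining` polls, then ready.
pub struct CountDown {
    remaining: u32,
}

impl Future for CountDown {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.remaining == 0 {
            return Poll::Ready(());
        }
        self.remaining -= 1;
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}

pub fn count_down(remaining: u32) -> BoxFuture {
    Box::pin(CountDown { remaining })
}

// A waker that does nothing; enough for a busy-polling driver.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

/// Busy-polls `fut` and returns how many polls it took to complete.
/// Loops forever if a member future never completes.
pub fn polls_until_ready(mut fut: JoinAll) -> u32 {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut polls = 0;
    loop {
        polls += 1;
        if Pin::new(&mut fut).poll(&mut cx).is_ready() {
            return polls;
        }
    }
}
```

With members that need 2 and 3 pending polls, the set completes on the fourth poll, so the slowest member determines the total. Awaiting the same futures sequentially would take more polls.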
114 changes: 86 additions & 28 deletions crates/wasi/src/preview2/pipe.rs
@@ -102,7 +102,9 @@ pub struct AsyncReadStream {
state: StreamState,
buffer: Option<Result<Bytes, std::io::Error>>,
receiver: tokio::sync::mpsc::Receiver<Result<(Bytes, StreamState), std::io::Error>>,
pub(crate) join_handle: tokio::task::JoinHandle<()>,
// the join handle for the background task is Some until join_background_tasks, after which
// further use of the AsyncReadStream is not allowed.
join_handle: Option<tokio::task::JoinHandle<()>>,
}

impl AsyncReadStream {
@@ -131,21 +133,37 @@ impl AsyncReadStream {
state: StreamState::Open,
buffer: None,
receiver,
join_handle,
join_handle: Some(join_handle),
}
}
// stdio implementation uses this to determine if the backing tokio runtime has been shutdown and
// restarted:
pub(crate) fn is_finished(&self) -> bool {
assert!(
self.join_handle.is_some(),
"illegal use of AsyncReadStream after join_background_tasks"
);
self.join_handle.as_ref().unwrap().is_finished()
}
}

// Make sure the background task does not outlive the AsyncReadStream handle.
// It will join on its own if it reaches a sender await, but if it is blocked
// on reader.read_buf's await it could hold the reader open indefinitely.
impl Drop for AsyncReadStream {
fn drop(&mut self) {
self.join_handle.abort()
self.join_handle.take().map(|h| h.abort());
}
}

#[async_trait::async_trait]
impl HostInputStream for AsyncReadStream {
fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> {
use tokio::sync::mpsc::error::TryRecvError;
assert!(
self.join_handle.is_some(),
"illegal use of AsyncReadStream after join_background_tasks"
);

match self.buffer.take() {
Some(Ok(mut bytes)) => {
@@ -188,6 +206,11 @@ impl HostInputStream for AsyncReadStream {
}

async fn ready(&mut self) -> Result<(), Error> {
assert!(
self.join_handle.is_some(),
"illegal use of AsyncReadStream after join_background_tasks"
);

if self.buffer.is_some() || self.state == StreamState::Closed {
return Ok(());
}
@@ -207,6 +230,9 @@ impl HostInputStream for AsyncReadStream {
}
Ok(())
}
async fn join_background_tasks(&mut self) {
self.join_handle.take().map(|h| h.abort());
}
}

#[derive(Debug)]
@@ -221,7 +247,9 @@ pub struct AsyncWriteStream {
state: Option<WriteState>,
sender: tokio::sync::mpsc::Sender<Bytes>,
result_receiver: tokio::sync::mpsc::Receiver<Result<StreamState, std::io::Error>>,
join_handle: tokio::task::JoinHandle<()>,
// the join handle for the background task is Some until join_background_tasks, after which
// further use of the AsyncWriteStream is not allowed.
join_handle: Option<tokio::task::JoinHandle<()>>,
}

impl AsyncWriteStream {
@@ -269,7 +297,7 @@ impl AsyncWriteStream {
state: Some(WriteState::Ready),
sender,
result_receiver,
join_handle,
join_handle: Some(join_handle),
}
}

@@ -290,18 +318,52 @@ impl AsyncWriteStream {
Err(TrySendError::Closed(_)) => unreachable!("task shouldn't die while not closed"),
}
}

// Factored out: when called from a wasm stream operation, there is a strict precondition that
// join_handle is_some, but when called as part of join_background_tasks, join_handle is None.
async fn ready_(&mut self) -> Result<(), Error> {
match &self.state {
Some(WriteState::Pending) => match self.result_receiver.recv().await {
Some(Ok(StreamState::Open)) => {
self.state = Some(WriteState::Ready);
}

Some(Ok(StreamState::Closed)) => {
self.state = None;
}

Some(Err(e)) => {
self.state = Some(WriteState::Err(e));
}

None => {
unreachable!("background task has died")
}
},

Some(WriteState::Ready | WriteState::Err(_)) | None => {}
}

Ok(())
}
}

// Make sure the background task does not outlive the AsyncWriteStream handle.
// In order to make sure output is flushed before dropping, use `join_background_tasks`.
impl Drop for AsyncWriteStream {
fn drop(&mut self) {
self.join_handle.abort()
self.join_handle.take().map(|h| h.abort());
}
}

#[async_trait::async_trait]
impl HostOutputStream for AsyncWriteStream {
fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), anyhow::Error> {
use tokio::sync::mpsc::error::TryRecvError;
assert!(
self.join_handle.is_some(),
"illegal use of AsyncWriteStream after join_background_tasks"
);

match self.state {
Some(WriteState::Ready) => self.send(bytes),
@@ -344,28 +406,24 @@ impl HostOutputStream for AsyncWriteStream {
}
}

async fn ready(&mut self) -> Result<(), Error> {
match &self.state {
Some(WriteState::Pending) => match self.result_receiver.recv().await {
Some(Ok(StreamState::Open)) => {
self.state = Some(WriteState::Ready);
}

Some(Ok(StreamState::Closed)) => {
self.state = None;
}

Some(Err(e)) => {
self.state = Some(WriteState::Err(e));
}

None => unreachable!("task shouldn't die while pending"),
},

Some(WriteState::Ready | WriteState::Err(_)) | None => {}
}

Ok(())
async fn ready(&mut self) -> anyhow::Result<()> {
assert!(
self.join_handle.is_some(),
"illegal use of AsyncWriteStream after join_background_tasks"
);
// Body is factored out: call from join_background_tasks does not have the above
// precondition.
self.ready_().await
}

async fn join_background_tasks(&mut self) {
// Do this at most once, after which the rest of the methods are not to be used:
if let Some(jh) = self.join_handle.take() {
// Worker task has flushed outputs when it is ready for writing again:
let _ = self.ready_().await;
// Now abort the task:
jh.abort();
};
}
}
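
The `ready` / `ready_` split above keeps a precondition check on the guest-facing entry point while letting `join_background_tasks` reuse the same body after it has taken the handle. A minimal synchronous sketch of that shape follows. The `Handle` type, its `u32` "task", and the `pending` counter are assumptions made for illustration, not types from this crate.

```rust
/// Hypothetical stream handle; the `u32` stands in for a task join handle.
pub struct Handle {
    task: Option<u32>,
    pub pending: usize,
}

impl Handle {
    pub fn new(task: u32, pending: usize) -> Self {
        Self { task: Some(task), pending }
    }

    // Unguarded body shared by both callers: pretend to flush pending writes.
    fn ready_(&mut self) {
        self.pending = 0;
    }

    /// Guest-facing entry point: only legal while the task handle is present.
    pub fn ready(&mut self) {
        assert!(
            self.task.is_some(),
            "illegal use of Handle after join_background_tasks"
        );
        self.ready_();
    }

    /// Host-only shutdown. Returns true only for the call that did the join.
    pub fn join_background_tasks(&mut self) -> bool {
        if let Some(_task) = self.task.take() {
            // The handle is already gone here, so the guarded `ready` would panic.
            self.ready_();
            true
        } else {
            false
        }
    }
}
```

Taking the handle before flushing makes a second call a no-op, and any later call to the guarded method fails loudly instead of silently touching a dead task.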

2 changes: 1 addition & 1 deletion crates/wasi/src/preview2/stdio/unix.rs
@@ -44,7 +44,7 @@ pub fn stdin() -> Stdin {
//
// As we can't tell the difference between these two, we assume the latter and restart the
// task.
if guard.join_handle.is_finished() {
if guard.is_finished() {
*guard = init_stdin();
}
}
10 changes: 10 additions & 0 deletions crates/wasi/src/preview2/stream.rs
@@ -43,6 +43,11 @@ pub trait HostInputStream: Send + Sync {
/// Check for read readiness: this method blocks until the stream is ready
/// for reading.
async fn ready(&mut self) -> Result<(), Error>;

/// Terminate all background tasks. Exposed only to the host, not accessible from WebAssembly.
/// Must cancel background tasks even if dropped before completion.
/// No other methods may be used after calling this method.
async fn join_background_tasks(&mut self) {}
}

/// Host trait for implementing the `wasi:io/streams.output-stream` resource:
@@ -88,6 +93,11 @@ pub trait HostOutputStream: Send + Sync {
/// Check for write readiness: this method blocks until the stream is
/// ready for writing.
async fn ready(&mut self) -> Result<(), Error>;

/// Flush any output which has been buffered, and terminate all background tasks. Exposed only
/// to the host, not accessible from WebAssembly. Must cancel background tasks even if dropped
/// before completion. No other methods may be used after calling this method.
async fn join_background_tasks(&mut self) {}
}

pub(crate) enum InternalInputStream {
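
Both trait methods added above have an empty default body, so existing implementors that spawn no tasks compile unchanged and only task-backed streams override it. The sketch below shows that effect with a simplified synchronous trait. `HostStream`, `InMemory`, `Worker`, and `join_all` are hypothetical names, and the return value (number of tasks joined) exists only to make the behavior observable.

```rust
/// Simplified, synchronous stand-in for the host stream traits.
pub trait HostStream {
    fn name(&self) -> &str;

    /// Default no-op: streams without background tasks need not override this.
    /// Returns the number of tasks joined.
    fn join_background_tasks(&mut self) -> usize {
        0
    }
}

/// A stream with no background work; relies on the default method.
pub struct InMemory;

impl HostStream for InMemory {
    fn name(&self) -> &str {
        "in-memory"
    }
}

/// A stream that owns some background tasks and overrides the default.
pub struct Worker {
    pub tasks: usize,
}

impl HostStream for Worker {
    fn name(&self) -> &str {
        "worker"
    }

    fn join_background_tasks(&mut self) -> usize {
        // Joining is one-shot: afterwards no tasks remain.
        std::mem::take(&mut self.tasks)
    }
}

/// Joins every stream and reports how many tasks were joined in total.
pub fn join_all(streams: &mut [Box<dyn HostStream>]) -> usize {
    streams.iter_mut().map(|s| s.join_background_tasks()).sum()
}
```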
5 changes: 5 additions & 0 deletions crates/wasi/src/preview2/table.rs
@@ -164,6 +164,11 @@ impl Table {
self.map.contains_key(&key)
}

/// Iterator of the keys in the table.
pub fn keys(&self) -> impl Iterator<Item = &u32> {
self.map.keys()
}

/// Check if the resource at a given index can be downcast to a given type.
/// Note: this will always fail if the resource is already borrowed.
pub fn is<T: Any + Sized>(&self, key: u32) -> bool {
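
`Table::keys` returns an iterator that borrows the table, which is why `join_background_tasks` in ctx.rs first collects the keys into a `Vec<u32>` before deleting entries. The sketch below shows the same snapshot-then-mutate pattern on a plain `HashMap`. Using `HashMap<u32, String>` as the table and the `drain_matching` helper are assumptions made for illustration.

```rust
use std::collections::HashMap;

/// Removes every entry whose value satisfies `pred` and returns the removed
/// values in ascending key order.
pub fn drain_matching(
    map: &mut HashMap<u32, String>,
    pred: impl Fn(&str) -> bool,
) -> Vec<String> {
    // Snapshot the keys first: the iterator from `map.keys()` borrows the map
    // immutably, so it cannot stay alive across `map.remove`.
    let mut keys: Vec<u32> = map.keys().cloned().collect();
    keys.sort_unstable();

    let mut removed = Vec::new();
    for k in keys {
        let matches = map.get(&k).map_or(false, |v| pred(v.as_str()));
        if matches {
            if let Some(v) = map.remove(&k) {
                removed.push(v);
            }
        }
    }
    removed
}
```

Iterating `map.keys()` directly while calling `map.remove` inside the loop is rejected by the borrow checker, so the short-lived copy of the keys is the price of mutating during the walk.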