diff --git a/host/src/filesystem.rs b/host/src/filesystem.rs index d115adc28e51..b1c31d0bd935 100644 --- a/host/src/filesystem.rs +++ b/host/src/filesystem.rs @@ -1,6 +1,6 @@ #![allow(unused_variables)] -use crate::wasi_poll::WasiStream; +use crate::wasi_poll::{InputStream, OutputStream}; use crate::{wasi_filesystem, HostResult, WasiCtx}; use std::{ io::{IoSlice, IoSliceMut}, @@ -656,7 +656,7 @@ impl wasi_filesystem::WasiFilesystem for WasiCtx { &mut self, fd: wasi_filesystem::Descriptor, offset: wasi_filesystem::Filesize, - ) -> HostResult { + ) -> HostResult { let f = self.table_mut().get_file_mut(fd).map_err(convert)?; // Duplicate the file descriptor so that we get an indepenent lifetime. @@ -666,7 +666,7 @@ impl wasi_filesystem::WasiFilesystem for WasiCtx { let reader = FileStream::new_reader(clone, offset); // Box it up. - let boxed: Box = Box::new(reader); + let boxed: Box = Box::new(reader); // Insert the stream view into the table. let index = self.table_mut().push(Box::new(boxed)).map_err(convert)?; @@ -678,7 +678,7 @@ impl wasi_filesystem::WasiFilesystem for WasiCtx { &mut self, fd: wasi_filesystem::Descriptor, offset: wasi_filesystem::Filesize, - ) -> HostResult { + ) -> HostResult { let f = self.table_mut().get_file_mut(fd).map_err(convert)?; // Duplicate the file descriptor so that we get an indepenent lifetime. @@ -688,7 +688,7 @@ impl wasi_filesystem::WasiFilesystem for WasiCtx { let writer = FileStream::new_writer(clone, offset); // Box it up. - let boxed: Box = Box::new(writer); + let boxed: Box = Box::new(writer); // Insert the stream view into the table. let index = self.table_mut().push(Box::new(boxed)).map_err(convert)?; @@ -699,7 +699,7 @@ impl wasi_filesystem::WasiFilesystem for WasiCtx { async fn append_via_stream( &mut self, fd: wasi_filesystem::Descriptor, - ) -> HostResult { + ) -> HostResult { let f = self.table_mut().get_file_mut(fd).map_err(convert)?; // Duplicate the file descriptor so that we get an indepenent lifetime. @@ -709,7 +709,7 @@ impl wasi_filesystem::WasiFilesystem for WasiCtx { let appender = FileStream::new_appender(clone); // Box it up. - let boxed: Box = Box::new(appender); + let boxed: Box = Box::new(appender); // Insert the stream view into the table. let index = self.table_mut().push(Box::new(boxed)).map_err(convert)?; diff --git a/host/src/io.rs b/host/src/io.rs new file mode 100644 index 000000000000..4dcb3ab29626 --- /dev/null +++ b/host/src/io.rs @@ -0,0 +1,154 @@ +use crate::{ + wasi_io::{InputStream, OutputStream, StreamError, WasiIo}, + HostResult, WasiCtx, +}; +use wasi_common::stream::TableStreamExt; + +fn convert(error: wasi_common::Error) -> anyhow::Error { + if let Some(_errno) = error.downcast_ref() { + anyhow::Error::new(StreamError {}) + } else { + error.into() + } +} + +#[async_trait::async_trait] +impl WasiIo for WasiCtx { + async fn drop_input_stream(&mut self, stream: InputStream) -> anyhow::Result<()> { + self.table_mut() + .delete::>(stream) + .map_err(convert)?; + Ok(()) + } + + async fn drop_output_stream(&mut self, stream: OutputStream) -> anyhow::Result<()> { + self.table_mut() + .delete::>(stream) + .map_err(convert)?; + Ok(()) + } + + async fn read( + &mut self, + stream: InputStream, + len: u64, + ) -> HostResult<(Vec, bool), StreamError> { + let s: &mut Box = self + .table_mut() + .get_input_stream_mut(stream) + .map_err(convert)?; + + let mut buffer = vec![0; len.try_into().unwrap()]; + + let (bytes_read, end) = s.read(&mut buffer).await.map_err(convert)?; + + buffer.truncate(bytes_read as usize); + + Ok(Ok((buffer, end))) + } + + async fn write( + &mut self, + stream: OutputStream, + bytes: Vec, + ) -> HostResult { + let s: &mut Box = self + .table_mut() + .get_output_stream_mut(stream) + .map_err(convert)?; + + let bytes_written: u64 = s.write(&bytes).await.map_err(convert)?; + + Ok(Ok(u64::try_from(bytes_written).unwrap())) + } + + async fn skip( + &mut self, + stream: InputStream, + len: u64, + ) -> HostResult<(u64, bool), StreamError> { + let s: &mut Box = self + .table_mut() + .get_input_stream_mut(stream) + .map_err(convert)?; + + let (bytes_skipped, end) = s.skip(len).await.map_err(convert)?; + + Ok(Ok((bytes_skipped, end))) + } + + async fn write_repeated( + &mut self, + stream: OutputStream, + byte: u8, + len: u64, + ) -> HostResult { + let s: &mut Box = self + .table_mut() + .get_output_stream_mut(stream) + .map_err(convert)?; + + let bytes_written: u64 = s.write_repeated(byte, len).await.map_err(convert)?; + + Ok(Ok(bytes_written)) + } + + async fn splice( + &mut self, + _src: InputStream, + _dst: OutputStream, + _len: u64, + ) -> HostResult<(u64, bool), StreamError> { + // TODO: We can't get two streams at the same time because they both + // carry the exclusive lifetime of `self`. When [`get_many_mut`] is + // stabilized, that could allow us to add a `get_many_stream_mut` or + // so which lets us do this. + // + // [`get_many_mut`]: https://doc.rust-lang.org/stable/std/collections/hash_map/struct.HashMap.html#method.get_many_mut + /* + let s: &mut Box = self + .table_mut() + .get_input_stream_mut(src) + .map_err(convert)?; + let d: &mut Box = self + .table_mut() + .get_output_stream_mut(dst) + .map_err(convert)?; + + let bytes_spliced: u64 = s.splice(&mut **d, len).await.map_err(convert)?; + + Ok(bytes_spliced) + */ + + todo!() + } + + async fn forward( + &mut self, + _src: InputStream, + _dst: OutputStream, + ) -> HostResult { + // TODO: We can't get two streams at the same time because they both + // carry the exclusive lifetime of `self`. When [`get_many_mut`] is + // stabilized, that could allow us to add a `get_many_stream_mut` or + // so which lets us do this. + // + // [`get_many_mut`]: https://doc.rust-lang.org/stable/std/collections/hash_map/struct.HashMap.html#method.get_many_mut + /* + let s: &mut Box = self + .table_mut() + .get_input_stream_mut(src) + .map_err(convert)?; + let d: &mut Box = self + .table_mut() + .get_output_stream_mut(dst) + .map_err(convert)?; + + let bytes_spliced: u64 = s.splice(&mut **d, len).await.map_err(convert)?; + + Ok(bytes_spliced) + */ + + todo!() + } +} diff --git a/host/src/lib.rs b/host/src/lib.rs index 37421bbc6ef8..40e188d520f6 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -1,6 +1,7 @@ mod clocks; mod exit; mod filesystem; +mod io; mod logging; mod poll; mod random; @@ -27,6 +28,7 @@ pub fn add_to_linker( wasi_logging::add_to_linker(l, f)?; wasi_stderr::add_to_linker(l, f)?; wasi_poll::add_to_linker(l, f)?; + wasi_io::add_to_linker(l, f)?; wasi_random::add_to_linker(l, f)?; wasi_tcp::add_to_linker(l, f)?; wasi_exit::add_to_linker(l, f)?; diff --git a/host/src/poll.rs b/host/src/poll.rs index 02313b5634e9..23beb00149bf 100644 --- a/host/src/poll.rs +++ b/host/src/poll.rs @@ -1,14 +1,15 @@ use crate::{ wasi_clocks, - wasi_poll::{self, Pollable, Size, StreamError, WasiPoll, WasiStream}, - HostResult, WasiCtx, + wasi_io::{InputStream, OutputStream, StreamError}, + wasi_poll::{Pollable, WasiPoll}, + WasiCtx, }; use wasi_common::clocks::TableMonotonicClockExt; use wasi_common::stream::TableStreamExt; fn convert(error: wasi_common::Error) -> anyhow::Error { if let Some(_errno) = error.downcast_ref() { - anyhow::Error::new(wasi_poll::StreamError {}) + anyhow::Error::new(StreamError {}) } else { error.into() } @@ -18,9 +19,9 @@ fn convert(error: wasi_common::Error) -> anyhow::Error { #[derive(Copy, Clone)] enum PollableEntry { /// Poll for read events. - Read(WasiStream), + Read(InputStream), /// Poll for write events. - Write(WasiStream), + Write(OutputStream), /// Poll for a monotonic-clock timer. MonotonicClock(wasi_clocks::MonotonicClock, wasi_clocks::Instant, bool), } @@ -43,13 +44,13 @@ impl WasiPoll for WasiCtx { for (index, future) in futures.into_iter().enumerate() { match *self.table().get(future).map_err(convert)? { PollableEntry::Read(stream) => { - let wasi_stream: &dyn wasi_common::WasiStream = - self.table().get_stream(stream).map_err(convert)?; + let wasi_stream: &dyn wasi_common::InputStream = + self.table().get_input_stream(stream).map_err(convert)?; poll.subscribe_read(wasi_stream, Userdata::from(index as u64)); } PollableEntry::Write(stream) => { - let wasi_stream: &dyn wasi_common::WasiStream = - self.table().get_stream(stream).map_err(convert)?; + let wasi_stream: &dyn wasi_common::OutputStream = + self.table().get_output_stream(stream).map_err(convert)?; poll.subscribe_write(wasi_stream, Userdata::from(index as u64)); } PollableEntry::MonotonicClock(clock, when, absolute) => { @@ -75,107 +76,13 @@ impl WasiPoll for WasiCtx { Ok(results) } - async fn drop_stream(&mut self, stream: WasiStream) -> anyhow::Result<()> { - self.table_mut() - .delete::>(stream) - .map_err(convert)?; - Ok(()) - } - - async fn read_stream( - &mut self, - stream: WasiStream, - len: Size, - ) -> HostResult<(Vec, bool), StreamError> { - let s: &mut Box = - self.table_mut().get_stream_mut(stream).map_err(convert)?; - - let mut buffer = vec![0; len.try_into().unwrap()]; - - let (bytes_read, end) = s.read(&mut buffer).await.map_err(convert)?; - - buffer.truncate(bytes_read as usize); - - Ok(Ok((buffer, end))) - } - - async fn write_stream( - &mut self, - stream: WasiStream, - bytes: Vec, - ) -> HostResult { - let s: &mut Box = - self.table_mut().get_stream_mut(stream).map_err(convert)?; - - let bytes_written: u64 = s.write(&bytes).await.map_err(convert)?; - - Ok(Ok(Size::try_from(bytes_written).unwrap())) - } - - async fn skip_stream( - &mut self, - stream: WasiStream, - len: u64, - ) -> HostResult<(u64, bool), StreamError> { - let s: &mut Box = - self.table_mut().get_stream_mut(stream).map_err(convert)?; - - let (bytes_skipped, end) = s.skip(len).await.map_err(convert)?; - - Ok(Ok((bytes_skipped, end))) - } - - async fn write_repeated_stream( - &mut self, - stream: WasiStream, - byte: u8, - len: u64, - ) -> HostResult { - let s: &mut Box = - self.table_mut().get_stream_mut(stream).map_err(convert)?; - - let bytes_written: u64 = s.write_repeated(byte, len).await.map_err(convert)?; - - Ok(Ok(bytes_written)) - } - - async fn splice_stream( - &mut self, - _src: WasiStream, - _dst: WasiStream, - _len: u64, - ) -> HostResult<(u64, bool), StreamError> { - // TODO: We can't get two streams at the same time because they both - // carry the exclusive lifetime of `self`. When [`get_many_mut`] is - // stabilized, that could allow us to add a `get_many_stream_mut` or - // so which lets us do this. - // - // [`get_many_mut`]: https://doc.rust-lang.org/stable/std/collections/hash_map/struct.HashMap.html#method.get_many_mut - /* - let s: &mut Box = self - .table_mut() - .get_stream_mut(src) - .map_err(convert)?; - let d: &mut Box = self - .table_mut() - .get_stream_mut(dst) - .map_err(convert)?; - - let bytes_spliced: u64 = s.splice(&mut **d, len).await.map_err(convert)?; - - Ok(bytes_spliced) - */ - - todo!() - } - - async fn subscribe_read(&mut self, stream: WasiStream) -> anyhow::Result { + async fn subscribe_read(&mut self, stream: InputStream) -> anyhow::Result { Ok(self .table_mut() .push(Box::new(PollableEntry::Read(stream)))?) } - async fn subscribe_write(&mut self, stream: WasiStream) -> anyhow::Result { + async fn subscribe_write(&mut self, stream: OutputStream) -> anyhow::Result { Ok(self .table_mut() .push(Box::new(PollableEntry::Write(stream)))?) diff --git a/host/src/tcp.rs b/host/src/tcp.rs index eca61b29569e..131a9f944def 100644 --- a/host/src/tcp.rs +++ b/host/src/tcp.rs @@ -1,7 +1,7 @@ #![allow(unused_variables)] use crate::{ - wasi_poll::WasiStream, + wasi_poll::{InputStream, OutputStream}, wasi_tcp::{ Connection, ConnectionFlags, Errno, IoSize, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, Listener, ListenerFlags, Network, TcpListener, WasiTcp, @@ -84,7 +84,7 @@ impl WasiTcp for WasiCtx { &mut self, listener: Listener, flags: ConnectionFlags, - ) -> HostResult<(Connection, WasiStream), Errno> { + ) -> HostResult<(Connection, InputStream, OutputStream), Errno> { let table = self.table_mut(); let l = table.get_listener_mut(listener)?; @@ -95,19 +95,20 @@ impl WasiTcp for WasiCtx { todo!() } - let (connection, stream) = l.accept(nonblocking).await?; + let (connection, input_stream, output_stream) = l.accept(nonblocking).await?; let connection = table.push(Box::new(connection)).map_err(convert)?; - let stream = table.push(Box::new(stream)).map_err(convert)?; + let input_stream = table.push(Box::new(input_stream)).map_err(convert)?; + let output_stream = table.push(Box::new(output_stream)).map_err(convert)?; - Ok(Ok((connection, stream))) + Ok(Ok((connection, input_stream, output_stream))) } async fn accept_tcp( &mut self, listener: TcpListener, flags: ConnectionFlags, - ) -> HostResult<(Connection, WasiStream, IpSocketAddress), Errno> { + ) -> HostResult<(Connection, InputStream, OutputStream, IpSocketAddress), Errno> { let table = self.table_mut(); let l = table.get_tcp_listener_mut(listener)?; @@ -118,12 +119,13 @@ impl WasiTcp for WasiCtx { todo!() } - let (connection, stream, addr) = l.accept(nonblocking).await?; + let (connection, input_stream, output_stream, addr) = l.accept(nonblocking).await?; let connection = table.push(Box::new(connection)).map_err(convert)?; - let stream = table.push(Box::new(stream)).map_err(convert)?; + let input_stream = table.push(Box::new(input_stream)).map_err(convert)?; + let output_stream = table.push(Box::new(output_stream)).map_err(convert)?; - Ok(Ok((connection, stream, addr.into()))) + Ok(Ok((connection, input_stream, output_stream, addr.into()))) } async fn connect( @@ -132,7 +134,7 @@ impl WasiTcp for WasiCtx { local_address: IpSocketAddress, remote_address: IpSocketAddress, flags: ConnectionFlags, - ) -> HostResult<(Connection, WasiStream), Errno> { + ) -> HostResult<(Connection, InputStream, OutputStream), Errno> { todo!() } diff --git a/host/tests/runtime.rs b/host/tests/runtime.rs index e4db2bdba35d..3d75143c572d 100644 --- a/host/tests/runtime.rs +++ b/host/tests/runtime.rs @@ -2,7 +2,7 @@ use anyhow::Result; use cap_rand::RngCore; use cap_std::{ambient_authority, fs::Dir, time::Duration}; use host::wasi_filesystem::Descriptor; -use host::wasi_poll::WasiStream; +use host::wasi_io::{InputStream, OutputStream}; use host::{add_to_linker, WasiCommand, WasiCtx}; use std::{ io::{Cursor, Write}, @@ -50,8 +50,8 @@ async fn instantiate(path: &str) -> Result<(Store, WasiCommand)> { async fn run_hello_stdout(mut store: Store, wasi: WasiCommand) -> Result<()> { wasi.command( &mut store, - 0 as WasiStream, - 1 as WasiStream, + 0 as InputStream, + 1 as OutputStream, &["gussie", "sparky", "willa"], &[], &[], @@ -64,8 +64,8 @@ async fn run_panic(mut store: Store, wasi: WasiCommand) -> Result<()> { let r = wasi .command( &mut store, - 0 as WasiStream, - 1 as WasiStream, + 0 as InputStream, + 1 as OutputStream, &[ "diesel", "the", @@ -88,8 +88,8 @@ async fn run_panic(mut store: Store, wasi: WasiCommand) -> Result<()> { async fn run_args(mut store: Store, wasi: WasiCommand) -> Result<()> { wasi.command( &mut store, - 0 as WasiStream, - 1 as WasiStream, + 0 as InputStream, + 1 as OutputStream, &["hello", "this", "", "is an argument", "with 🚩 emoji"], &[], &[], @@ -121,9 +121,16 @@ async fn run_random(mut store: Store, wasi: WasiCommand) -> Result<()> store.data_mut().random = Box::new(FakeRng); - wasi.command(&mut store, 0 as WasiStream, 1 as WasiStream, &[], &[], &[]) - .await? - .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) + wasi.command( + &mut store, + 0 as InputStream, + 1 as OutputStream, + &[], + &[], + &[], + ) + .await? + .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) } async fn run_time(mut store: Store, wasi: WasiCommand) -> Result<()> { @@ -171,9 +178,16 @@ async fn run_time(mut store: Store, wasi: WasiCommand) -> Result<()> { store.data_mut().clocks.default_monotonic_clock = Box::new(FakeMonotonicClock { now: Mutex::new(0) }); - wasi.command(&mut store, 0 as WasiStream, 1 as WasiStream, &[], &[], &[]) - .await? - .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) + wasi.command( + &mut store, + 0 as InputStream, + 1 as OutputStream, + &[], + &[], + &[], + ) + .await? + .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) } async fn run_stdin(mut store: Store, wasi: WasiCommand) -> Result<()> { @@ -183,9 +197,16 @@ async fn run_stdin(mut store: Store, wasi: WasiCommand) -> Result<()> { "So rested he by the Tumtum tree", )))); - wasi.command(&mut store, 0 as WasiStream, 1 as WasiStream, &[], &[], &[]) - .await? - .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) + wasi.command( + &mut store, + 0 as InputStream, + 1 as OutputStream, + &[], + &[], + &[], + ) + .await? + .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) } async fn run_poll_stdin(mut store: Store, wasi: WasiCommand) -> Result<()> { @@ -195,9 +216,16 @@ async fn run_poll_stdin(mut store: Store, wasi: WasiCommand) -> Result< "So rested he by the Tumtum tree", )))); - wasi.command(&mut store, 0 as WasiStream, 1 as WasiStream, &[], &[], &[]) - .await? - .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) + wasi.command( + &mut store, + 0 as InputStream, + 1 as OutputStream, + &[], + &[], + &[], + ) + .await? + .map_err(|()| anyhow::anyhow!("command returned with failing exit status")) } async fn run_env(mut store: Store, wasi: WasiCommand) -> Result<()> { @@ -415,8 +443,8 @@ async fn run_with_temp_dir(mut store: Store, wasi: WasiCommand) -> Resu wasi.command( &mut store, - 0 as WasiStream, - 1 as WasiStream, + 0 as InputStream, + 1 as OutputStream, &["program", "/foo"], &[], &[(descriptor, "/foo")], diff --git a/wasi-common/cap-std-sync/src/lib.rs b/wasi-common/cap-std-sync/src/lib.rs index a2f0eb1dd02f..af35c1b323a3 100644 --- a/wasi-common/cap-std-sync/src/lib.rs +++ b/wasi-common/cap-std-sync/src/lib.rs @@ -47,7 +47,12 @@ pub use sched::sched_ctx; use crate::net::Listener; use cap_rand::{Rng, RngCore, SeedableRng}; -use wasi_common::{listener::WasiListener, stream::WasiStream, table::Table, WasiCtx}; +use wasi_common::{ + listener::WasiListener, + stream::{InputStream, OutputStream}, + table::Table, + WasiCtx, +}; pub struct WasiCtxBuilder(WasiCtx); @@ -60,15 +65,15 @@ impl WasiCtxBuilder { Table::new(), )) } - pub fn stdin(mut self, f: Box) -> Self { + pub fn stdin(mut self, f: Box) -> Self { self.0.set_stdin(f); self } - pub fn stdout(mut self, f: Box) -> Self { + pub fn stdout(mut self, f: Box) -> Self { self.0.set_stdout(f); self } - pub fn stderr(mut self, f: Box) -> Self { + pub fn stderr(mut self, f: Box) -> Self { self.0.set_stderr(f); self } diff --git a/wasi-common/cap-std-sync/src/net.rs b/wasi-common/cap-std-sync/src/net.rs index 30eb430f4bf8..eecb4213d77f 100644 --- a/wasi-common/cap-std-sync/src/net.rs +++ b/wasi-common/cap-std-sync/src/net.rs @@ -1,4 +1,4 @@ -use io_extras::borrowed::BorrowedWriteable; +use io_extras::borrowed::BorrowedReadable; #[cfg(windows)] use io_extras::os::windows::{AsHandleOrSocket, BorrowedHandleOrSocket}; use io_lifetimes::AsSocketlike; @@ -17,7 +17,7 @@ use system_interface::io::ReadReady; use wasi_common::{ connection::{RiFlags, RoFlags, SdFlags, SiFlags, WasiConnection}, listener::WasiListener, - stream::WasiStream, + stream::{InputStream, OutputStream}, tcp_listener::WasiTcpListener, Error, ErrorExt, }; @@ -109,12 +109,24 @@ macro_rules! wasi_listener_impl { async fn accept( &mut self, nonblocking: bool, - ) -> Result<(Box, Box), Error> { + ) -> Result< + ( + Box, + Box, + Box, + ), + Error, + > { let (stream, _) = self.0.accept()?; stream.set_nonblocking(nonblocking)?; let connection = <$stream>::from_cap_std(stream); - let stream = connection.clone(); - Ok((Box::new(connection), Box::new(stream))) + let input_stream = connection.clone(); + let output_stream = connection.clone(); + Ok(( + Box::new(connection), + Box::new(input_stream), + Box::new(output_stream), + )) } fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error> { @@ -159,12 +171,26 @@ macro_rules! wasi_tcp_listener_impl { async fn accept( &mut self, nonblocking: bool, - ) -> Result<(Box, Box, SocketAddr), Error> { + ) -> Result< + ( + Box, + Box, + Box, + SocketAddr, + ), + Error, + > { let (stream, addr) = self.0.accept()?; stream.set_nonblocking(nonblocking)?; let connection = <$stream>::from_cap_std(stream); - let stream = connection.clone(); - Ok((Box::new(connection), Box::new(stream), addr)) + let input_stream = connection.clone(); + let output_stream = connection.clone(); + Ok(( + Box::new(connection), + Box::new(input_stream), + Box::new(output_stream), + addr, + )) } fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error> { @@ -288,7 +314,7 @@ macro_rules! wasi_stream_write_impl { } #[async_trait::async_trait] - impl WasiStream for $ty { + impl InputStream for $ty { fn as_any(&self) -> &dyn Any { self } @@ -296,19 +322,11 @@ macro_rules! wasi_stream_write_impl { fn pollable_read(&self) -> Option { Some(self.0.as_fd()) } - #[cfg(unix)] - fn pollable_write(&self) -> Option { - Some(self.0.as_fd()) - } #[cfg(windows)] fn pollable_read(&self) -> Option { Some(self.0.as_handle_or_socket()) } - #[cfg(windows)] - fn pollable_write(&self) -> Option { - Some(self.0.as_handle_or_socket()) - } async fn read(&mut self, buf: &mut [u8]) -> Result<(u64, bool), Error> { match Read::read(&mut &*self.as_socketlike_view::<$std_ty>(), buf) { @@ -333,6 +351,41 @@ macro_rules! wasi_stream_write_impl { fn is_read_vectored(&self) -> bool { Read::is_read_vectored(&mut &*self.as_socketlike_view::<$std_ty>()) } + + async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { + let num = io::copy(&mut io::Read::take(&*self.0, nelem), &mut io::sink())?; + Ok((num, num < nelem)) + } + + async fn num_ready_bytes(&self) -> Result { + let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?; + Ok(val) + } + + async fn readable(&self) -> Result<(), Error> { + if is_read_write(&*self.0)?.0 { + Ok(()) + } else { + Err(Error::badf()) + } + } + } + #[async_trait::async_trait] + impl OutputStream for $ty { + fn as_any(&self) -> &dyn Any { + self + } + + #[cfg(unix)] + fn pollable_write(&self) -> Option { + Some(self.0.as_fd()) + } + + #[cfg(windows)] + fn pollable_write(&self) -> Option { + Some(self.0.as_handle_or_socket()) + } + async fn write(&mut self, buf: &[u8]) -> Result { let n = Write::write(&mut &*self.as_socketlike_view::<$std_ty>(), buf)?; Ok(n.try_into()?) @@ -347,38 +400,23 @@ macro_rules! wasi_stream_write_impl { } async fn splice( &mut self, - dst: &mut dyn WasiStream, + src: &mut dyn InputStream, nelem: u64, ) -> Result<(u64, bool), Error> { - if let Some(writeable) = dst.pollable_write() { + if let Some(readable) = src.pollable_read() { let num = io::copy( - &mut io::Read::take(&*self.0, nelem), - &mut BorrowedWriteable::borrow(writeable), + &mut io::Read::take(BorrowedReadable::borrow(readable), nelem), + &mut &*self.0, )?; Ok((num, num < nelem)) } else { - WasiStream::splice(self, dst, nelem).await + OutputStream::splice(self, src, nelem).await } } - async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { - let num = io::copy(&mut io::Read::take(&*self.0, nelem), &mut io::sink())?; - Ok((num, num < nelem)) - } async fn write_repeated(&mut self, byte: u8, nelem: u64) -> Result { let num = io::copy(&mut io::Read::take(io::repeat(byte), nelem), &mut &*self.0)?; Ok(num) } - async fn num_ready_bytes(&self) -> Result { - let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?; - Ok(val) - } - async fn readable(&self) -> Result<(), Error> { - if is_read_write(&*self.0)?.0 { - Ok(()) - } else { - Err(Error::badf()) - } - } async fn writable(&self) -> Result<(), Error> { if is_read_write(&*self.0)?.1 { Ok(()) diff --git a/wasi-common/cap-std-sync/src/sched.rs b/wasi-common/cap-std-sync/src/sched.rs index eff89bbde53f..94d5a2c14f06 100644 --- a/wasi-common/cap-std-sync/src/sched.rs +++ b/wasi-common/cap-std-sync/src/sched.rs @@ -1,7 +1,7 @@ use rustix::io::{PollFd, PollFlags}; use std::thread; use std::time::Duration; -use wasi_common::sched::subscription::{RwEventFlags, RwSubscriptionKind}; +use wasi_common::sched::subscription::{RwEventFlags, RwStream}; use wasi_common::{ sched::{Poll, WasiSched}, Error, ErrorExt, @@ -12,11 +12,11 @@ pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> { // separately below. let mut ready = false; let mut pollfds = Vec::new(); - for (rwsub, kind) in poll.rw_subscriptions() { - match kind { - RwSubscriptionKind::Read => { + for rwsub in poll.rw_subscriptions() { + match rwsub.stream { + RwStream::Read(stream) => { // Poll things that can be polled. - if let Some(fd) = rwsub.stream.pollable_read() { + if let Some(fd) = stream.pollable_read() { #[cfg(unix)] { pollfds.push(PollFd::from_borrowed_fd(fd, PollFlags::IN)); @@ -34,7 +34,7 @@ pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> { // Allow in-memory buffers or other immediately-available // sources to complete successfully. - if let Ok(nbytes) = rwsub.stream.num_ready_bytes().await { + if let Ok(nbytes) = stream.num_ready_bytes().await { if nbytes != 0 { rwsub.complete(RwEventFlags::empty()); ready = true; @@ -45,8 +45,8 @@ pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> { return Err(Error::invalid_argument().context("stream is not pollable for reading")); } - RwSubscriptionKind::Write => { - let fd = rwsub.stream.pollable_write().ok_or( + RwStream::Write(stream) => { + let fd = stream.pollable_write().ok_or( Error::invalid_argument().context("stream is not pollable for writing"), )?; @@ -98,22 +98,13 @@ pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> { } } - assert_eq!( - poll.rw_subscriptions() - .filter(|(sub, _kind)| !sub.is_complete()) - .count(), - pollfds.len() - ); + assert_eq!(poll.rw_subscriptions().count(), pollfds.len()); // If the OS `poll` returned events, record them. if ready { // Iterate through the stream subscriptions, skipping those that // were already completed due to being immediately available. - for ((rwsub, _kind), pollfd) in poll - .rw_subscriptions() - .filter(|(sub, _kind)| !sub.is_complete()) - .zip(pollfds.into_iter()) - { + for (rwsub, pollfd) in poll.rw_subscriptions().zip(pollfds.into_iter()) { let revents = pollfd.revents(); if revents.contains(PollFlags::NVAL) { rwsub.error(Error::badf()); diff --git a/wasi-common/cap-std-sync/src/stdio.rs b/wasi-common/cap-std-sync/src/stdio.rs index a8aeb456f766..3087f477d15d 100644 --- a/wasi-common/cap-std-sync/src/stdio.rs +++ b/wasi-common/cap-std-sync/src/stdio.rs @@ -10,7 +10,10 @@ use io_extras::os::windows::{AsHandleOrSocket, BorrowedHandleOrSocket}; use io_lifetimes::{AsFd, BorrowedFd}; #[cfg(windows)] use io_lifetimes::{AsHandle, BorrowedHandle}; -use wasi_common::{stream::WasiStream, Error, ErrorExt}; +use wasi_common::{ + stream::{InputStream, OutputStream}, + Error, ErrorExt, +}; pub struct Stdin(std::io::Stdin); @@ -19,7 +22,7 @@ pub fn stdin() -> Stdin { } #[async_trait::async_trait] -impl WasiStream for Stdin { +impl InputStream for Stdin { fn as_any(&self) -> &dyn Any { self } @@ -56,27 +59,6 @@ impl WasiStream for Stdin { fn is_read_vectored(&self) { Read::is_read_vectored(&mut self.0) } - async fn write(&mut self, _buf: &[u8]) -> Result { - Err(Error::badf()) - } - async fn write_vectored<'a>(&mut self, _bufs: &[io::IoSlice<'a>]) -> Result { - Err(Error::badf()) - } - #[cfg(can_vector)] - fn is_write_vectored(&self) { - false - } - - // TODO: Optimize for stdio streams. - /* - async fn splice( - &mut self, - dst: &mut dyn WasiStream, - nelem: u64, - ) -> Result { - todo!() - } - */ async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { let num = io::copy(&mut io::Read::take(&mut self.0, nelem), &mut io::sink())?; @@ -90,10 +72,6 @@ impl WasiStream for Stdin { async fn readable(&self) -> Result<(), Error> { Err(Error::badf()) } - - async fn writable(&self) -> Result<(), Error> { - Ok(()) - } } #[cfg(windows)] impl AsHandle for Stdin { @@ -118,7 +96,7 @@ impl AsFd for Stdin { macro_rules! wasi_file_write_impl { ($ty:ty, $ident:ident) => { #[async_trait::async_trait] - impl WasiStream for $ty { + impl OutputStream for $ty { fn as_any(&self) -> &dyn Any { self } @@ -132,19 +110,6 @@ macro_rules! wasi_file_write_impl { Some(self.0.as_handle_or_socket()) } - async fn read(&mut self, _buf: &mut [u8]) -> Result<(u64, bool), Error> { - Err(Error::badf()) - } - async fn read_vectored<'a>( - &mut self, - _bufs: &mut [io::IoSliceMut<'a>], - ) -> Result<(u64, bool), Error> { - Err(Error::badf()) - } - #[cfg(can_vector)] - fn is_read_vectored(&self) { - false - } async fn write(&mut self, buf: &[u8]) -> Result { let n = Write::write(&mut self.0, buf)?; Ok(n.try_into()?) @@ -161,7 +126,7 @@ macro_rules! wasi_file_write_impl { /* async fn splice( &mut self, - dst: &mut dyn WasiStream, + src: &mut dyn InputStream, nelem: u64, ) -> Result { todo!() @@ -173,10 +138,6 @@ macro_rules! wasi_file_write_impl { Ok(num) } - async fn readable(&self) -> Result<(), Error> { - Err(Error::badf()) - } - async fn writable(&self) -> Result<(), Error> { Ok(()) } diff --git a/wasi-common/src/ctx.rs b/wasi-common/src/ctx.rs index 869cb30474fd..1b4a4968dcca 100644 --- a/wasi-common/src/ctx.rs +++ b/wasi-common/src/ctx.rs @@ -3,7 +3,7 @@ use crate::dir::WasiDir; use crate::file::WasiFile; use crate::listener::WasiListener; use crate::sched::WasiSched; -use crate::stream::WasiStream; +use crate::stream::{InputStream, OutputStream}; use crate::table::Table; use crate::Error; use cap_rand::RngCore; @@ -38,7 +38,11 @@ impl WasiCtx { self.table_mut().insert_at(fd, Box::new(file)); } - pub fn insert_stream(&mut self, fd: u32, stream: Box) { + pub fn insert_input_stream(&mut self, fd: u32, stream: Box) { + self.table_mut().insert_at(fd, Box::new(stream)); + } + + pub fn insert_output_stream(&mut self, fd: u32, stream: Box) { self.table_mut().insert_at(fd, Box::new(stream)); } @@ -66,15 +70,15 @@ impl WasiCtx { &mut self.table } - pub fn set_stdin(&mut self, s: Box) { - self.insert_stream(0, s); + pub fn set_stdin(&mut self, s: Box) { + self.insert_input_stream(0, s); } - pub fn set_stdout(&mut self, s: Box) { - self.insert_stream(1, s); + pub fn set_stdout(&mut self, s: Box) { + self.insert_output_stream(1, s); } - pub fn set_stderr(&mut self, s: Box) { - self.insert_stream(2, s); + pub fn set_stderr(&mut self, s: Box) { + self.insert_output_stream(2, s); } } diff --git a/wasi-common/src/file.rs b/wasi-common/src/file.rs index f8f3eef479b0..77e1043e8a17 100644 --- a/wasi-common/src/file.rs +++ b/wasi-common/src/file.rs @@ -1,4 +1,4 @@ -use crate::{Error, ErrorExt, SystemTimeSpec, WasiStream}; +use crate::{Error, ErrorExt, InputStream, OutputStream, SystemTimeSpec}; use bitflags::bitflags; use std::any::Any; use std::io; @@ -236,7 +236,7 @@ impl FileStream { } #[async_trait::async_trait] -impl WasiStream for FileStream { +impl InputStream for FileStream { fn as_any(&self) -> &dyn Any { self } @@ -250,15 +250,6 @@ impl WasiStream for FileStream { } } - #[cfg(unix)] - fn pollable_write(&self) -> Option { - if let FileStreamType::Read(_) = self.type_ { - None - } else { - self.file.pollable() - } - } - #[cfg(windows)] fn pollable_read(&self) -> Option { if let FileStreamType::Read(_) = self.type_ { @@ -268,15 +259,6 @@ impl WasiStream for FileStream { } } - #[cfg(windows)] - fn pollable_write(&self) -> Option { - if let FileStreamType::Read(_) = self.type_ { - None - } else { - self.file.pollable() - } - } - async fn read(&mut self, buf: &mut [u8]) -> Result<(u64, bool), Error> { if let FileStreamType::Read(position) = &mut self.type_ { let (n, end) = self.file.read_at(buf, *position).await?; @@ -309,6 +291,68 @@ impl WasiStream for FileStream { } } + async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { + // For a zero-length request, don't do the 1-byte check below. + if nelem == 0 { + return self.file.read_at(&mut [], 0).await; + } + + if let FileStreamType::Read(position) = &mut self.type_ { + let new_position = position.checked_add(nelem).ok_or_else(Error::overflow)?; + + let file_size = self.file.get_filestat().await?.size; + + let short_by = new_position.saturating_sub(file_size); + + *position = new_position - short_by; + Ok((nelem - short_by, false)) + } else { + Err(Error::badf()) + } + } + + async fn num_ready_bytes(&self) -> Result { + if let FileStreamType::Read(_) = self.type_ { + // Default to saying that no data is ready. + Ok(0) + } else { + Err(Error::badf()) + } + } + + async fn readable(&self) -> Result<(), Error> { + if let FileStreamType::Read(_) = self.type_ { + self.file.readable().await + } else { + Err(Error::badf()) + } + } +} + +#[async_trait::async_trait] +impl OutputStream for FileStream { + fn as_any(&self) -> &dyn Any { + self + } + + #[cfg(unix)] + fn pollable_write(&self) -> Option { + if let FileStreamType::Read(_) = self.type_ { + None + } else { + self.file.pollable() + } + } + + #[cfg(windows)] + fn pollable_write(&self) -> Option { + if let FileStreamType::Read(_) = self.type_ { + None + } else { + self.file.pollable() + } + } + async fn write(&mut self, buf: &[u8]) -> Result { match &mut self.type_ { FileStreamType::Write(position) => { @@ -352,33 +396,13 @@ impl WasiStream for FileStream { /* async fn splice( &mut self, - dst: &mut dyn WasiStream, + src: &mut dyn InputStream, nelem: u64, ) -> Result { todo!() } */ - async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { - // For a zero-length request, don't do the 1-byte check below. - if nelem == 0 { - return self.file.read_at(&mut [], 0).await; - } - - if let FileStreamType::Read(position) = &mut self.type_ { - let new_position = position.checked_add(nelem).ok_or_else(Error::overflow)?; - - let file_size = self.file.get_filestat().await?.size; - - let short_by = new_position.saturating_sub(file_size); - - *position = new_position - short_by; - Ok((nelem - short_by, false)) - } else { - Err(Error::badf()) - } - } - // TODO: Optimize for file streams. /* async fn write_repeated( @@ -390,23 +414,6 @@ impl WasiStream for FileStream { } */ - async fn num_ready_bytes(&self) -> Result { - if let FileStreamType::Read(_) = self.type_ { - // Default to saying that no data is ready. - Ok(0) - } else { - Err(Error::badf()) - } - } - - async fn readable(&self) -> Result<(), Error> { - if let FileStreamType::Read(_) = self.type_ { - self.file.readable().await - } else { - Err(Error::badf()) - } - } - async fn writable(&self) -> Result<(), Error> { if let FileStreamType::Read(_) = self.type_ { Err(Error::badf()) diff --git a/wasi-common/src/lib.rs b/wasi-common/src/lib.rs index 187a61d595ae..c3d70f7d7a6f 100644 --- a/wasi-common/src/lib.rs +++ b/wasi-common/src/lib.rs @@ -75,6 +75,6 @@ pub use error::{Errno, Error, ErrorExt, I32Exit}; pub use file::WasiFile; pub use listener::WasiListener; pub use sched::{Poll, WasiSched}; -pub use stream::WasiStream; +pub use stream::{InputStream, OutputStream}; pub use table::Table; pub use tcp_listener::WasiTcpListener; diff --git a/wasi-common/src/listener.rs b/wasi-common/src/listener.rs index f83d13f697f8..b8eb34a71741 100644 --- a/wasi-common/src/listener.rs +++ b/wasi-common/src/listener.rs @@ -2,7 +2,7 @@ use crate::connection::WasiConnection; use crate::Error; -use crate::WasiStream; +use crate::{InputStream, OutputStream}; use std::any::Any; /// A socket listener. @@ -13,7 +13,14 @@ pub trait WasiListener: Send + Sync { async fn accept( &mut self, nonblocking: bool, - ) -> Result<(Box, Box), Error>; + ) -> Result< + ( + Box, + Box, + Box, + ), + Error, + >; fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error>; } diff --git a/wasi-common/src/pipe.rs b/wasi-common/src/pipe.rs index e699f074eb5c..26d6f58ede7a 100644 --- a/wasi-common/src/pipe.rs +++ b/wasi-common/src/pipe.rs @@ -7,8 +7,8 @@ //! Some convenience constructors are included for common backing types like `Vec` and `String`, //! but the virtual pipes can be instantiated with any `Read` or `Write` type. //! -use crate::stream::WasiStream; -use crate::{Error, ErrorExt}; +use crate::stream::{InputStream, OutputStream}; +use crate::Error; use std::any::Any; use std::convert::TryInto; use std::io::{self, Read, Write}; @@ -106,7 +106,7 @@ impl From<&str> for ReadPipe> { } #[async_trait::async_trait] -impl WasiStream for ReadPipe { +impl InputStream for ReadPipe { fn as_any(&self) -> &dyn Any { self } @@ -124,17 +124,6 @@ impl WasiStream for ReadPipe { } } - // TODO: Optimize for pipes. - /* - async fn splice( - &mut self, - dst: &mut dyn WasiStream, - nelem: u64, - ) -> Result { - todo!() - } - */ - async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { let num = io::copy( &mut io::Read::take(&mut *self.borrow(), nelem), @@ -146,10 +135,6 @@ impl WasiStream for ReadPipe { async fn readable(&self) -> Result<(), Error> { Ok(()) } - - async fn writable(&self) -> Result<(), Error> { - Err(Error::badf()) - } } /// A virtual pipe write end. @@ -223,7 +208,7 @@ impl WritePipe>> { } #[async_trait::async_trait] -impl WasiStream for WritePipe { +impl OutputStream for WritePipe { fn as_any(&self) -> &dyn Any { self } @@ -237,7 +222,7 @@ impl WasiStream for WritePipe { /* async fn splice( &mut self, - dst: &mut dyn WasiStream, + src: &mut dyn InputStream, nelem: u64, ) -> Result { todo!() @@ -252,10 +237,6 @@ impl WasiStream for WritePipe { Ok(num) } - async fn readable(&self) -> Result<(), Error> { - Err(Error::badf()) - } - async fn writable(&self) -> Result<(), Error> { Ok(()) } diff --git a/wasi-common/src/sched.rs b/wasi-common/src/sched.rs index b81d4b174930..223d494f0e4f 100644 --- a/wasi-common/src/sched.rs +++ b/wasi-common/src/sched.rs @@ -1,12 +1,11 @@ use crate::clocks::WasiMonotonicClock; -use crate::stream::WasiStream; +use crate::stream::{InputStream, OutputStream}; use crate::Error; pub mod subscription; pub use cap_std::time::Duration; pub use subscription::{ - MonotonicClockSubscription, RwEventFlags, RwSubscription, RwSubscriptionKind, Subscription, - SubscriptionResult, + MonotonicClockSubscription, RwEventFlags, RwSubscription, Subscription, SubscriptionResult, }; #[async_trait::async_trait] @@ -58,15 +57,15 @@ impl<'a> Poll<'a> { ud, )); } - pub fn subscribe_read(&mut self, stream: &'a dyn WasiStream, ud: Userdata) { + pub fn subscribe_read(&mut self, stream: &'a dyn InputStream, ud: Userdata) { self.subs.push(( - Subscription::ReadWrite(RwSubscription::new(stream), RwSubscriptionKind::Read), + Subscription::ReadWrite(RwSubscription::new_input(stream)), ud, )); } - pub fn subscribe_write(&mut self, stream: &'a dyn WasiStream, ud: Userdata) { + pub fn subscribe_write(&mut self, stream: &'a dyn OutputStream, ud: Userdata) { self.subs.push(( - Subscription::ReadWrite(RwSubscription::new(stream), RwSubscriptionKind::Write), + Subscription::ReadWrite(RwSubscription::new_output(stream)), ud, )); } @@ -87,11 +86,9 @@ impl<'a> Poll<'a> { }) .min_by(|a, b| a.deadline.cmp(&b.deadline)) } - pub fn rw_subscriptions<'b>( - &'b mut self, - ) -> impl Iterator, RwSubscriptionKind)> { - self.subs.iter_mut().filter_map(|(s, _ud)| match s { - Subscription::ReadWrite(sub, kind) => Some((sub, *kind)), + pub fn rw_subscriptions<'b>(&'b mut self) -> impl Iterator> { + self.subs.iter_mut().filter_map(|sub| match &mut sub.0 { + Subscription::ReadWrite(rwsub) => Some(rwsub), _ => None, }) } diff --git a/wasi-common/src/sched/subscription.rs b/wasi-common/src/sched/subscription.rs index 47f62d2fcc32..73918a0e59f9 100644 --- a/wasi-common/src/sched/subscription.rs +++ b/wasi-common/src/sched/subscription.rs @@ -1,5 +1,5 @@ use crate::clocks::WasiMonotonicClock; -use crate::stream::WasiStream; +use crate::stream::{InputStream, OutputStream}; use crate::Error; use bitflags::bitflags; @@ -9,15 +9,26 @@ bitflags! { } } +pub enum RwStream<'a> { + Read(&'a dyn InputStream), + Write(&'a dyn OutputStream), +} + pub struct RwSubscription<'a> { - pub stream: &'a dyn WasiStream, + pub stream: RwStream<'a>, status: Option>, } impl<'a> RwSubscription<'a> { - pub fn new(stream: &'a dyn WasiStream) -> Self { + pub fn new_input(stream: &'a dyn InputStream) -> Self { + Self { + stream: RwStream::Read(stream), + status: None, + } + } + pub fn new_output(stream: &'a dyn OutputStream) -> Self { Self { - stream, + stream: RwStream::Write(stream), status: None, } } @@ -57,28 +68,22 @@ impl<'a> MonotonicClockSubscription<'a> { } pub enum Subscription<'a> { - ReadWrite(RwSubscription<'a>, RwSubscriptionKind), + ReadWrite(RwSubscription<'a>), MonotonicClock(MonotonicClockSubscription<'a>), } -#[derive(Copy, Clone, Debug)] -pub enum RwSubscriptionKind { - Read, - Write, -} - #[derive(Debug)] pub enum SubscriptionResult { - ReadWrite(Result, RwSubscriptionKind), + ReadWrite(Result), MonotonicClock(Result<(), Error>), } impl SubscriptionResult { pub fn from_subscription(s: Subscription) -> Option { match s { - Subscription::ReadWrite(mut s, kind) => s - .result() - .map(|sub| SubscriptionResult::ReadWrite(sub, kind)), + Subscription::ReadWrite(mut s) => { + s.result().map(|sub| SubscriptionResult::ReadWrite(sub)) + } Subscription::MonotonicClock(s) => s.result().map(SubscriptionResult::MonotonicClock), } } diff --git a/wasi-common/src/stream.rs b/wasi-common/src/stream.rs index b72ce4bc1e43..4cc230eb0b14 100644 --- a/wasi-common/src/stream.rs +++ b/wasi-common/src/stream.rs @@ -1,13 +1,13 @@ use crate::{Error, ErrorExt}; use std::any::Any; -/// A pseudo-stream. +/// An input bytestream. /// /// This is "pseudo" because the real streams will be a type in wit, and /// built into the wit bindings, and will support async and type parameters. /// This pseudo-stream abstraction is synchronous and only supports bytes. #[async_trait::async_trait] -pub trait WasiStream: Send + Sync { +pub trait InputStream: Send + Sync { fn as_any(&self) -> &dyn Any; /// If this stream is reading from a host file descriptor, return it so @@ -24,20 +24,6 @@ pub trait WasiStream: Send + Sync { None } - /// If this stream is writing from a host file descriptor, return it so - /// that it can be polled with a host poll. - #[cfg(unix)] - fn pollable_write(&self) -> Option { - None - } - - /// If this stream is writing from a host file descriptor, return it so - /// that it can be polled with a host poll. - #[cfg(windows)] - fn pollable_write(&self) -> Option { - None - } - /// Read bytes. On success, returns a pair holding the number of bytes read /// and a flag indicating whether the end of the stream was reached. async fn read(&mut self, _buf: &mut [u8]) -> Result<(u64, bool), Error> { @@ -58,6 +44,56 @@ pub trait WasiStream: Send + Sync { false } + /// Read bytes from a stream and discard them. + async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { + let mut nread = 0; + let mut saw_end = false; + + // TODO: Optimize by reading more than one byte at a time. + for _ in 0..nelem { + let (num, end) = self.read(&mut [0]).await?; + nread += num; + if end { + saw_end = true; + break; + } + } + + Ok((nread, saw_end)) + } + + /// Return the number of bytes that may be read without blocking. + async fn num_ready_bytes(&self) -> Result { + Ok(0) + } + + /// Test whether this stream is readable. + async fn readable(&self) -> Result<(), Error>; +} + +/// An output bytestream. +/// +/// This is "pseudo" because the real streams will be a type in wit, and +/// built into the wit bindings, and will support async and type parameters. +/// This pseudo-stream abstraction is synchronous and only supports bytes. +#[async_trait::async_trait] +pub trait OutputStream: Send + Sync { + fn as_any(&self) -> &dyn Any; + + /// If this stream is writing from a host file descriptor, return it so + /// that it can be polled with a host poll. + #[cfg(unix)] + fn pollable_write(&self) -> Option { + None + } + + /// If this stream is writing from a host file descriptor, return it so + /// that it can be polled with a host poll. + #[cfg(windows)] + fn pollable_write(&self) -> Option { + None + } + /// Write bytes. On success, returns the number of bytes written. async fn write(&mut self, _buf: &[u8]) -> Result { Err(Error::badf()) @@ -75,15 +111,19 @@ pub trait WasiStream: Send + Sync { } /// Transfer bytes directly from an input stream to an output stream. - async fn splice(&mut self, dst: &mut dyn WasiStream, nelem: u64) -> Result<(u64, bool), Error> { + async fn splice( + &mut self, + src: &mut dyn InputStream, + nelem: u64, + ) -> Result<(u64, bool), Error> { let mut nspliced = 0; let mut saw_end = false; // TODO: Optimize by splicing more than one byte at a time. for _ in 0..nelem { let mut buf = [0u8]; - let (num, end) = self.read(&mut buf).await?; - dst.write(&buf).await?; + let (num, end) = src.read(&mut buf).await?; + self.write(&buf).await?; nspliced += num; if end { saw_end = true; @@ -94,24 +134,6 @@ pub trait WasiStream: Send + Sync { Ok((nspliced, saw_end)) } - /// Read bytes from a stream and discard them. - async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { - let mut nread = 0; - let mut saw_end = false; - - // TODO: Optimize by reading more than one byte at a time. - for _ in 0..nelem { - let (num, end) = self.read(&mut [0]).await?; - nread += num; - if end { - saw_end = true; - break; - } - } - - Ok((nread, saw_end)) - } - /// Repeatedly write a byte to a stream. async fn write_repeated(&mut self, byte: u8, nelem: u64) -> Result { let mut nwritten = 0; @@ -128,27 +150,29 @@ pub trait WasiStream: Send + Sync { Ok(nwritten) } - /// Return the number of bytes that may be read without blocking. - async fn num_ready_bytes(&self) -> Result { - Ok(0) - } - - /// Test whether this stream is readable. - async fn readable(&self) -> Result<(), Error>; - /// Test whether this stream is writeable. async fn writable(&self) -> Result<(), Error>; } pub trait TableStreamExt { - fn get_stream(&self, fd: u32) -> Result<&dyn WasiStream, Error>; - fn get_stream_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; + fn get_input_stream(&self, fd: u32) -> Result<&dyn InputStream, Error>; + fn get_input_stream_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; + + fn get_output_stream(&self, fd: u32) -> Result<&dyn OutputStream, Error>; + fn get_output_stream_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; } impl TableStreamExt for crate::table::Table { - fn get_stream(&self, fd: u32) -> Result<&dyn WasiStream, Error> { - self.get::>(fd).map(|f| f.as_ref()) + fn get_input_stream(&self, fd: u32) -> Result<&dyn InputStream, Error> { + self.get::>(fd).map(|f| f.as_ref()) + } + fn get_input_stream_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { + self.get_mut::>(fd) + } + + fn get_output_stream(&self, fd: u32) -> Result<&dyn OutputStream, Error> { + self.get::>(fd).map(|f| f.as_ref()) } - fn get_stream_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { - self.get_mut::>(fd) + fn get_output_stream_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { + self.get_mut::>(fd) } } diff --git a/wasi-common/src/tcp_listener.rs b/wasi-common/src/tcp_listener.rs index b182453b1a8f..337567a838dc 100644 --- a/wasi-common/src/tcp_listener.rs +++ b/wasi-common/src/tcp_listener.rs @@ -3,7 +3,7 @@ use crate::connection::WasiConnection; use crate::Error; use crate::WasiListener; -use crate::WasiStream; +use crate::{InputStream, OutputStream}; use std::any::Any; use std::net::SocketAddr; @@ -15,7 +15,15 @@ pub trait WasiTcpListener: Send + Sync { async fn accept( &mut self, nonblocking: bool, - ) -> Result<(Box, Box, SocketAddr), Error>; + ) -> Result< + ( + Box, + Box, + Box, + SocketAddr, + ), + Error, + >; fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error>;