Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Host-wasmtime-rust: import functions are able to Trap execution #388

Merged
merged 11 commits into from
Oct 21, 2022
352 changes: 26 additions & 326 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ pulldown-cmark = { version = "0.8", default-features = false }
clap = { version = "4.0.9", features = ["derive"] }
env_logger = "0.9.1"

wasmtime = { git = "https://github.com/bytecodealliance/wasmtime", branch = "main" , features = ["component-model"] }
wasmtime-wasi = { git = "https://github.com/bytecodealliance/wasmtime", branch = "main" }
wasmtime-environ = { git = "https://github.com/bytecodealliance/wasmtime", branch = "main" }
wasmtime = { git = "https://github.com/bytecodealliance/wasmtime", features = ["component-model"] }
wasmtime-environ = { git = "https://github.com/bytecodealliance/wasmtime" }
wasmprinter = "0.2.41"
wasmparser = "0.92.0"
wasm-encoder = "0.18.0"
Expand Down
63 changes: 45 additions & 18 deletions crates/bindgen-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ pub struct Types {
type_info: HashMap<TypeId, TypeInfo>,
}

#[derive(Default, Clone, Copy)]
#[derive(Default, Clone, Copy, Debug)]
pub struct TypeInfo {
/// Whether or not this type is ever used (transitively) within the
/// parameter of a function.
Expand All @@ -179,6 +179,10 @@ pub struct TypeInfo {
/// result of a function.
pub result: bool,

/// Whether or not this type is ever used (transitively) within the
/// error case in the result of a function.
pub error: bool,

/// Whether or not this type (transitively) has a list.
pub has_list: bool,
}
Expand All @@ -187,6 +191,7 @@ impl std::ops::BitOrAssign for TypeInfo {
fn bitor_assign(&mut self, rhs: Self) {
self.param |= rhs.param;
self.result |= rhs.result;
self.error |= rhs.error;
self.has_list |= rhs.has_list;
}
}
Expand All @@ -198,10 +203,10 @@ impl Types {
}
for f in iface.functions.iter() {
for (_, ty) in f.params.iter() {
self.set_param_result_ty(iface, ty, true, false);
self.set_param_result_ty(iface, ty, true, false, false);
}
for ty in f.results.iter_types() {
self.set_param_result_ty(iface, ty, false, true);
self.set_param_result_ty(iface, ty, false, true, false);
}
}
}
Expand Down Expand Up @@ -281,56 +286,77 @@ impl Types {
}
}

fn set_param_result_id(&mut self, iface: &Interface, ty: TypeId, param: bool, result: bool) {
fn set_param_result_id(
&mut self,
iface: &Interface,
ty: TypeId,
param: bool,
result: bool,
error: bool,
) {
match &iface.types[ty].kind {
TypeDefKind::Record(r) => {
for field in r.fields.iter() {
self.set_param_result_ty(iface, &field.ty, param, result)
self.set_param_result_ty(iface, &field.ty, param, result, error)
}
}
TypeDefKind::Tuple(t) => {
for ty in t.types.iter() {
self.set_param_result_ty(iface, ty, param, result)
self.set_param_result_ty(iface, ty, param, result, error)
}
}
TypeDefKind::Flags(_) => {}
TypeDefKind::Enum(_) => {}
TypeDefKind::Variant(v) => {
for case in v.cases.iter() {
self.set_param_result_optional_ty(iface, case.ty.as_ref(), param, result)
self.set_param_result_optional_ty(iface, case.ty.as_ref(), param, result, error)
}
}
TypeDefKind::List(ty) | TypeDefKind::Type(ty) | TypeDefKind::Option(ty) => {
self.set_param_result_ty(iface, ty, param, result)
self.set_param_result_ty(iface, ty, param, result, error)
}
TypeDefKind::Result(r) => {
self.set_param_result_optional_ty(iface, r.ok.as_ref(), param, result);
self.set_param_result_optional_ty(iface, r.err.as_ref(), param, result);
self.set_param_result_optional_ty(iface, r.ok.as_ref(), param, result, error);
self.set_param_result_optional_ty(iface, r.err.as_ref(), param, result, result);
}
TypeDefKind::Union(u) => {
for case in u.cases.iter() {
self.set_param_result_ty(iface, &case.ty, param, result)
self.set_param_result_ty(iface, &case.ty, param, result, error)
}
}
TypeDefKind::Future(ty) => {
self.set_param_result_optional_ty(iface, ty.as_ref(), param, result)
self.set_param_result_optional_ty(iface, ty.as_ref(), param, result, error)
}
TypeDefKind::Stream(stream) => {
self.set_param_result_optional_ty(iface, stream.element.as_ref(), param, result);
self.set_param_result_optional_ty(iface, stream.end.as_ref(), param, result);
self.set_param_result_optional_ty(
iface,
stream.element.as_ref(),
param,
result,
error,
);
self.set_param_result_optional_ty(iface, stream.end.as_ref(), param, result, error);
}
}
}

fn set_param_result_ty(&mut self, iface: &Interface, ty: &Type, param: bool, result: bool) {
fn set_param_result_ty(
&mut self,
iface: &Interface,
ty: &Type,
param: bool,
result: bool,
error: bool,
) {
match ty {
Type::Id(id) => {
self.type_id_info(iface, *id);
let info = self.type_info.get_mut(id).unwrap();
if (param && !info.param) || (result && !info.result) {
if (param && !info.param) || (result && !info.result) || (error && !info.error) {
info.param = info.param || param;
info.result = info.result || result;
self.set_param_result_id(iface, *id, param, result);
info.error = info.error || error;
self.set_param_result_id(iface, *id, param, result, error);
}
}
_ => {}
Expand All @@ -343,9 +369,10 @@ impl Types {
ty: Option<&Type>,
param: bool,
result: bool,
error: bool,
) {
match ty {
Some(ty) => self.set_param_result_ty(iface, ty, param, result),
Some(ty) => self.set_param_result_ty(iface, ty, param, result, error),
None => (),
}
}
Expand Down
1 change: 0 additions & 1 deletion crates/gen-host-wasmtime-rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ clap = { workspace = true, optional = true }
anyhow = { workspace = true }
test-helpers = { path = '../test-helpers' }
wasmtime = { workspace = true }
wasmtime-wasi = { workspace = true }
wit-bindgen-host-wasmtime-rust = { workspace = true, features = ['tracing'] }

tokio = { version = "1", features = ["full"] }
Expand Down
101 changes: 94 additions & 7 deletions crates/gen-host-wasmtime-rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ impl WorldGenerator for Wasmtime {
fn import(&mut self, name: &str, iface: &Interface, _files: &mut Files) {
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::Owned);
gen.types();
gen.generate_from_error_impls();
gen.generate_add_to_linker(name);

let snake = name.to_snake_case();
Expand All @@ -77,6 +78,7 @@ impl WorldGenerator for Wasmtime {
fn export(&mut self, name: &str, iface: &Interface, _files: &mut Files) {
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a"));
gen.types();
gen.generate_from_error_impls();

let camel = name.to_upper_camel_case();
uwriteln!(gen.src, "pub struct {camel} {{");
Expand Down Expand Up @@ -312,6 +314,27 @@ impl<'a> InterfaceGenerator<'a> {
}
}

fn special_case_host_error(&self, results: &Results) -> Option<&Result_> {
// We only support the wit_bindgen_host_wasmtime_rust::Error case when
// a function has just one result, which is itself a `result<a, e>`, and the
// `e` is *not* a primitive (i.e. defined in std) type.
let mut i = results.iter_types();
if i.len() == 1 {
match i.next().unwrap() {
Type::Id(id) => match &self.iface.types[*id].kind {
TypeDefKind::Result(r) => match r.err {
Some(Type::Id(_)) => Some(&r),
_ => None,
},
_ => None,
},
_ => None,
}
} else {
None
}
}

fn generate_add_to_linker(&mut self, name: &str) {
let camel = name.to_upper_camel_case();

Expand All @@ -327,12 +350,34 @@ impl<'a> InterfaceGenerator<'a> {
fnsig.private = true;
fnsig.self_arg = Some("&mut self".to_string());

// These trait method args used to be TypeMode::LeafBorrowed, but wasmtime
// Lift is not impled for borrowed types, so I don't think we can
// support that anymore?
self.print_docs_and_params(func, TypeMode::Owned, &fnsig);
self.push_str(" -> ");
self.print_result_ty(&func.results, TypeMode::Owned);

if let Some(r) = self.special_case_host_error(&func.results).cloned() {
// Functions which have a single result `result<ok,err>` get special
// cased to use the host_wasmtime_rust::Error<err>, making it possible
// for them to trap or use `?` to propogate their errors
self.push_str("wit_bindgen_host_wasmtime_rust::Result<");
if let Some(ok) = r.ok {
self.print_ty(&ok, TypeMode::Owned);
} else {
self.push_str("()");
}
self.push_str(",");
if let Some(err) = r.err {
self.print_ty(&err, TypeMode::Owned);
} else {
self.push_str("()");
}
self.push_str(">");
} else {
// All other functions get their return values wrapped in an anyhow::Result.
// Returning the anyhow::Error case can be used to trap.
self.push_str("anyhow::Result<");
self.print_result_ty(&func.results, TypeMode::Owned);
self.push_str(">");
}

self.push_str(";\n");
}
uwriteln!(self.src, "}}");
Expand Down Expand Up @@ -420,10 +465,22 @@ impl<'a> InterfaceGenerator<'a> {
} else {
uwrite!(self.src, ");\n");
}
if func.results.iter_types().len() == 1 {
uwrite!(self.src, "Ok((r,))\n");

if self.special_case_host_error(&func.results).is_some() {
uwrite!(
self.src,
"match r {{
Ok(a) => Ok((Ok(a),)),
Err(e) => match e.downcast() {{
Ok(api_error) => Ok((Err(api_error),)),
Err(anyhow_error) => Err(anyhow_error),
}}
}}"
);
} else if func.results.iter_types().len() == 1 {
uwrite!(self.src, "Ok((r?,))\n");
} else {
uwrite!(self.src, "Ok(r)\n");
uwrite!(self.src, "r\n");
}

if self.gen.opts.async_ {
Expand Down Expand Up @@ -553,6 +610,36 @@ impl<'a> InterfaceGenerator<'a> {
// End function body
self.src.push_str("}\n");
}

fn generate_from_error_impls(&mut self) {
for (id, ty) in self.iface.types.iter() {
if ty.name.is_none() {
continue;
}
let info = self.info(id);
if info.error {
for (name, mode) in self.modes_of(id) {
let name = name.to_upper_camel_case();
if self.lifetime_for(&info, mode).is_some() {
continue;
}
self.push_str("impl From<");
self.push_str(&name);
self.push_str("> for wit_bindgen_host_wasmtime_rust::Error<");
self.push_str(&name);
self.push_str("> {\n");
self.push_str("fn from(e: ");
self.push_str(&name);
self.push_str(") -> wit_bindgen_host_wasmtime_rust::Error::< ");
self.push_str(&name);
self.push_str("> {\n");
self.push_str("wit_bindgen_host_wasmtime_rust::Error::new(e)\n");
self.push_str("}\n");
self.push_str("}\n");
}
}
}
}
}

impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> {
Expand Down
3 changes: 2 additions & 1 deletion crates/gen-host-wasmtime-rust/tests/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ wit_bindgen_host_wasmtime_rust::generate!({
pub struct TestWasi;

impl testwasi::Testwasi for TestWasi {
fn log(&mut self, bytes: Vec<u8>) {
fn log(&mut self, bytes: Vec<u8>) -> Result<()> {
match std::str::from_utf8(&bytes) {
Ok(s) => print!("{}", s),
Err(_) => println!("\nbinary: {:?}", bytes),
}
Ok(())
}
}
Loading