From 22a08628e53218991121d5f806aaf3ea03065259 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 22 Nov 2022 09:03:48 -0800 Subject: [PATCH 01/13] Import Wasmtime support from the `wit-bindgen` repo This commit imports the `wit-bindgen-gen-host-wasmtime-rust` crate from the `wit-bindgen` repository into the upstream Wasmtime repository. I've chosen to not import the full history here since the crate is relatively small and doesn't have a ton of complexity. While the history of the crate is quite long the current iteration of the crate's history is relatively short so there's not a ton of import there anyway. The thinking is that this can now continue to evolve in-tree. --- Cargo.lock | 42 ++ Cargo.toml | 3 + crates/wiggle/generate/Cargo.toml | 2 +- crates/wit-bindgen/Cargo.toml | 13 + crates/wit-bindgen/src/lib.rs | 1138 +++++++++++++++++++++++++++++ crates/wit-bindgen/src/rust.rs | 412 +++++++++++ crates/wit-bindgen/src/source.rs | 130 ++++ crates/wit-bindgen/src/types.rs | 207 ++++++ supply-chain/audits.toml | 6 + 9 files changed, 1952 insertions(+), 1 deletion(-) create mode 100644 crates/wit-bindgen/Cargo.toml create mode 100644 crates/wit-bindgen/src/lib.rs create mode 100644 crates/wit-bindgen/src/rust.rs create mode 100644 crates/wit-bindgen/src/source.rs create mode 100644 crates/wit-bindgen/src/types.rs diff --git a/Cargo.lock b/Cargo.lock index e25b37e05aff..09c827339710 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2235,6 +2235,17 @@ dependencies = [ "cc", ] +[[package]] +name = "pulldown-cmark" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffade02495f22453cd593159ea2f59827aae7f53fa8323f756799b670881dcf8" +dependencies = [ + "bitflags", + "memchr", + "unicase", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2969,6 +2980,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.8" @@ -3500,6 +3520,7 @@ dependencies = [ "quote", "syn", "wasmtime-component-util", + "wasmtime-wit-bindgen", ] [[package]] @@ -3757,6 +3778,14 @@ dependencies = [ "winch-codegen", ] +[[package]] +name = "wasmtime-wit-bindgen" +version = "4.0.0" +dependencies = [ + "heck", + "wit-parser", +] + [[package]] name = "wast" version = "35.0.2" @@ -3986,6 +4015,19 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "wit-parser" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "893834cffb239f88413eead7cf91862a6f24c2233afae15d7808256d8c58f91e" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "pulldown-cmark", + "unicode-xid", +] + [[package]] name = "witx" version = "0.9.1" diff --git a/Cargo.toml b/Cargo.toml index 9c5071c447ab..7f9cf233a7c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -133,6 +133,7 @@ wasi-tokio = { path = "crates/wasi-common/tokio", version = "=5.0.0" } wasi-cap-std-sync = { path = "crates/wasi-common/cap-std-sync", version = "=5.0.0" } wasmtime-fuzzing = { path = "crates/fuzzing" } wasmtime-jit-icache-coherence = { path = "crates/jit-icache-coherence", version = "=5.0.0" } +wasmtime-wit-bindgen = { path = "crates/wit-bindgen", version = "=5.0.0" } cranelift-wasm = { path = "cranelift/wasm", version = "0.92.0" } cranelift-codegen = { path = "cranelift/codegen", version = "0.92.0" } @@ -162,6 +163,7 @@ wasmprinter = "0.2.44" wasm-encoder = "0.20.0" wasm-smith = "0.11.9" wasm-mutate = "0.2.12" +wit-parser = "0.3" windows-sys = "0.42.0" env_logger = "0.9" rustix = "0.36.0" @@ -179,6 +181,7 @@ tracing = "0.1.26" bitflags = "1.2" thiserror = "1.0.15" async-trait = "0.1.42" +heck = "0.4" [features] default = [ diff --git a/crates/wiggle/generate/Cargo.toml b/crates/wiggle/generate/Cargo.toml index 2a84e3262ad7..b26054bd3030 100644 --- a/crates/wiggle/generate/Cargo.toml +++ b/crates/wiggle/generate/Cargo.toml @@ -17,7 +17,7 @@ include = ["src/**/*", "README.md", "LICENSE"] witx = { version = "0.9.1", path = "../../wasi-common/WASI/tools/witx" } quote = "1.0" proc-macro2 = "1.0" -heck = "0.4" +heck = { workspace = true } anyhow = { workspace = true } syn = { version = "1.0", features = ["full"] } shellexpand = "2.0" diff --git a/crates/wit-bindgen/Cargo.toml b/crates/wit-bindgen/Cargo.toml new file mode 100644 index 000000000000..7511f7b1cff0 --- /dev/null +++ b/crates/wit-bindgen/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "wasmtime-wit-bindgen" +version.workspace = true +authors.workspace = true +description = "Internal `*.wit` support for the `wasmtime` crate's macros" +license = "Apache-2.0 WITH LLVM-exception" +repository = "https://github.com/bytecodealliance/wasmtime" +documentation = "https://docs.rs/wasmtime-wit-bindgen/" +edition.workspace = true + +[dependencies] +heck = { workspace = true } +wit-parser = { workspace = true } diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs new file mode 100644 index 000000000000..2cfa960e02b7 --- /dev/null +++ b/crates/wit-bindgen/src/lib.rs @@ -0,0 +1,1138 @@ +use crate::rust::{to_rust_ident, RustGenerator, TypeMode}; +use crate::types::{TypeInfo, Types}; +use heck::*; +use std::collections::BTreeMap; +use std::fmt::Write as _; +use std::io::{Read, Write}; +use std::mem; +use std::process::{Command, Stdio}; +use wit_parser::*; + +macro_rules! uwrite { + ($dst:expr, $($arg:tt)*) => { + write!($dst, $($arg)*).unwrap() + }; +} + +macro_rules! uwriteln { + ($dst:expr, $($arg:tt)*) => { + writeln!($dst, $($arg)*).unwrap() + }; +} + +mod rust; +mod source; +mod types; +use source::Source; + +#[derive(Default)] +struct Wasmtime { + src: Source, + opts: Opts, + imports: Vec, + exports: Exports, +} + +#[derive(Default)] +struct Exports { + fields: BTreeMap, + funcs: Vec, +} + +#[derive(Default, Debug, Clone)] +#[cfg_attr(feature = "clap", derive(clap::Args))] +pub struct Opts { + /// Whether or not `rustfmt` is executed to format generated code. + #[cfg_attr(feature = "clap", arg(long))] + pub rustfmt: bool, + + /// Whether or not to emit `tracing` macro calls on function entry/exit. + #[cfg_attr(feature = "clap", arg(long))] + pub tracing: bool, + + /// Whether or not to use async rust functions and traits. + #[cfg_attr(feature = "clap", arg(long = "async"))] + pub async_: bool, +} + +impl Opts { + pub fn generate(&self, world: &World) -> String { + let mut r = Wasmtime::default(); + r.opts = self.clone(); + r.generate(world) + } +} + +impl Wasmtime { + fn generate(&mut self, world: &World) -> String { + for (name, import) in world.imports.iter() { + self.import(name, import); + } + for (name, export) in world.exports.iter() { + self.export(name, export); + } + if let Some(iface) = &world.default { + self.export_default(&world.name, iface); + } + self.finish(world) + } + + fn import(&mut self, name: &str, iface: &Interface) { + let mut gen = InterfaceGenerator::new(self, iface, TypeMode::Owned); + gen.types(); + gen.generate_from_error_impls(); + gen.generate_add_to_linker(name); + + let snake = name.to_snake_case(); + let module = &gen.src[..]; + + uwriteln!( + self.src, + " + #[allow(clippy::all)] + pub mod {snake} {{ + #[allow(unused_imports)] + use wit_bindgen_host_wasmtime_rust::{{wasmtime, anyhow}}; + + {module} + }} + " + ); + + self.imports.push(snake); + } + + fn export(&mut self, name: &str, iface: &Interface) { + let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a")); + gen.types(); + gen.generate_from_error_impls(); + + let camel = name.to_upper_camel_case(); + uwriteln!(gen.src, "pub struct {camel} {{"); + for func in iface.functions.iter() { + uwriteln!( + gen.src, + "{}: wasmtime::component::Func,", + func.name.to_snake_case() + ); + } + uwriteln!(gen.src, "}}"); + + uwriteln!(gen.src, "impl {camel} {{"); + uwrite!( + gen.src, + " + pub fn new( + __exports: &mut wasmtime::component::ExportInstance<'_, '_>, + ) -> anyhow::Result<{camel}> {{ + " + ); + let fields = gen.extract_typed_functions(); + for (name, getter) in fields.iter() { + uwriteln!(gen.src, "let {name} = {getter};"); + } + uwriteln!(gen.src, "Ok({camel} {{"); + for (name, _) in fields.iter() { + uwriteln!(gen.src, "{name},"); + } + uwriteln!(gen.src, "}})"); + uwriteln!(gen.src, "}}"); + for func in iface.functions.iter() { + gen.define_rust_guest_export(Some(name), func); + } + uwriteln!(gen.src, "}}"); + + let snake = name.to_snake_case(); + let module = &gen.src[..]; + + uwriteln!( + self.src, + " + #[allow(clippy::all)] + pub mod {snake} {{ + #[allow(unused_imports)] + use wit_bindgen_host_wasmtime_rust::{{wasmtime, anyhow}}; + + {module} + }} + " + ); + + let getter = format!( + "\ + {snake}::{camel}::new( + &mut __exports.instance(\"{name}\") + .ok_or_else(|| anyhow::anyhow!(\"exported instance `{name}` not present\"))? + )?\ + " + ); + let prev = self + .exports + .fields + .insert(snake.clone(), (format!("{snake}::{camel}"), getter)); + assert!(prev.is_none()); + self.exports.funcs.push(format!( + " + pub fn {snake}(&self) -> &{snake}::{camel} {{ + &self.{snake} + }} + " + )); + } + + fn export_default(&mut self, _name: &str, iface: &Interface) { + let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a")); + gen.types(); + gen.generate_from_error_impls(); + let fields = gen.extract_typed_functions(); + for (name, getter) in fields { + let prev = gen + .gen + .exports + .fields + .insert(name, ("wasmtime::component::Func".to_string(), getter)); + assert!(prev.is_none()); + } + + for func in iface.functions.iter() { + let prev = mem::take(&mut gen.src); + gen.define_rust_guest_export(None, func); + let func = mem::replace(&mut gen.src, prev); + gen.gen.exports.funcs.push(func.to_string()); + } + + let src = gen.src; + self.src.push_str(&src); + } + + fn finish(&mut self, world: &World) -> String { + let camel = world.name.to_upper_camel_case(); + uwriteln!(self.src, "pub struct {camel} {{"); + for (name, (ty, _)) in self.exports.fields.iter() { + uwriteln!(self.src, "{name}: {ty},"); + } + self.src.push_str("}\n"); + + let (async_, async__, send, await_) = if self.opts.async_ { + ("async", "_async", ":Send", ".await") + } else { + ("", "", "", "") + }; + + uwriteln!( + self.src, + " + impl {camel} {{ + /// Instantiates the provided `module` using the specified + /// parameters, wrapping up the result in a structure that + /// translates between wasm and the host. + pub {async_} fn instantiate{async__}( + mut store: impl wasmtime::AsContextMut, + component: &wasmtime::component::Component, + linker: &wasmtime::component::Linker, + ) -> anyhow::Result<(Self, wasmtime::component::Instance)> {{ + let instance = linker.instantiate{async__}(&mut store, component){await_}?; + Ok((Self::new(store, &instance)?, instance)) + }} + + /// Low-level creation wrapper for wrapping up the exports + /// of the `instance` provided in this structure of wasm + /// exports. + /// + /// This function will extract exports from the `instance` + /// defined within `store` and wrap them all up in the + /// returned structure which can be used to interact with + /// the wasm module. + pub fn new( + mut store: impl wasmtime::AsContextMut, + instance: &wasmtime::component::Instance, + ) -> anyhow::Result {{ + let mut store = store.as_context_mut(); + let mut exports = instance.exports(&mut store); + let mut __exports = exports.root(); + ", + ); + for (name, (_, get)) in self.exports.fields.iter() { + uwriteln!(self.src, "let {name} = {get};"); + } + uwriteln!(self.src, "Ok({camel} {{"); + for (name, _) in self.exports.fields.iter() { + uwriteln!(self.src, "{name},"); + } + uwriteln!(self.src, "}})"); + uwriteln!(self.src, "}}"); + + for func in self.exports.funcs.iter() { + self.src.push_str(func); + } + + uwriteln!(self.src, "}}"); + + let mut src = mem::take(&mut self.src); + if self.opts.rustfmt { + let mut child = Command::new("rustfmt") + .arg("--edition=2018") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("failed to spawn `rustfmt`"); + child + .stdin + .take() + .unwrap() + .write_all(src.as_bytes()) + .unwrap(); + src.as_mut_string().truncate(0); + child + .stdout + .take() + .unwrap() + .read_to_string(src.as_mut_string()) + .unwrap(); + let status = child.wait().unwrap(); + assert!(status.success()); + } + + src.into() + } +} + +struct InterfaceGenerator<'a> { + src: Source, + gen: &'a mut Wasmtime, + iface: &'a Interface, + default_param_mode: TypeMode, + types: Types, +} + +impl<'a> InterfaceGenerator<'a> { + fn new( + gen: &'a mut Wasmtime, + iface: &'a Interface, + default_param_mode: TypeMode, + ) -> InterfaceGenerator<'a> { + let mut types = Types::default(); + types.analyze(iface); + InterfaceGenerator { + src: Source::default(), + gen, + iface, + types, + default_param_mode, + } + } + + fn types(&mut self) { + for (id, ty) in self.iface.types.iter() { + let name = match &ty.name { + Some(name) => name, + None => continue, + }; + match &ty.kind { + TypeDefKind::Record(record) => self.type_record(id, name, record, &ty.docs), + TypeDefKind::Flags(flags) => self.type_flags(id, name, flags, &ty.docs), + TypeDefKind::Tuple(tuple) => self.type_tuple(id, name, tuple, &ty.docs), + TypeDefKind::Enum(enum_) => self.type_enum(id, name, enum_, &ty.docs), + TypeDefKind::Variant(variant) => self.type_variant(id, name, variant, &ty.docs), + TypeDefKind::Option(t) => self.type_option(id, name, t, &ty.docs), + TypeDefKind::Result(r) => self.type_result(id, name, r, &ty.docs), + TypeDefKind::Union(u) => self.type_union(id, name, u, &ty.docs), + TypeDefKind::List(t) => self.type_list(id, name, t, &ty.docs), + TypeDefKind::Type(t) => self.type_alias(id, name, t, &ty.docs), + TypeDefKind::Future(_) => todo!("generate for future"), + TypeDefKind::Stream(_) => todo!("generate for stream"), + } + } + } + + fn type_record(&mut self, id: TypeId, _name: &str, record: &Record, docs: &Docs) { + let info = self.info(id); + for (name, mode) in self.modes_of(id) { + let lt = self.lifetime_for(&info, mode); + self.rustdoc(docs); + + self.push_str("#[derive(wasmtime::component::ComponentType)]\n"); + if lt.is_none() { + self.push_str("#[derive(wasmtime::component::Lift)]\n"); + } + self.push_str("#[derive(wasmtime::component::Lower)]\n"); + self.push_str("#[component(record)]\n"); + + if !info.has_list { + self.push_str("#[derive(Copy, Clone)]\n"); + } else { + self.push_str("#[derive(Clone)]\n"); + } + self.push_str(&format!("pub struct {}", name)); + self.print_generics(lt); + self.push_str(" {\n"); + for field in record.fields.iter() { + self.rustdoc(&field.docs); + self.push_str(&format!("#[component(name = \"{}\")]\n", field.name)); + self.push_str("pub "); + self.push_str(&to_rust_ident(&field.name)); + self.push_str(": "); + self.print_ty(&field.ty, mode); + self.push_str(",\n"); + } + self.push_str("}\n"); + + self.push_str("impl"); + self.print_generics(lt); + self.push_str(" core::fmt::Debug for "); + self.push_str(&name); + self.print_generics(lt); + self.push_str(" {\n"); + self.push_str( + "fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", + ); + self.push_str(&format!("f.debug_struct(\"{}\")", name)); + for field in record.fields.iter() { + self.push_str(&format!( + ".field(\"{}\", &self.{})", + field.name, + to_rust_ident(&field.name) + )); + } + self.push_str(".finish()\n"); + self.push_str("}\n"); + self.push_str("}\n"); + + if info.error { + self.push_str("impl"); + self.print_generics(lt); + self.push_str(" core::fmt::Display for "); + self.push_str(&name); + self.print_generics(lt); + self.push_str(" {\n"); + self.push_str( + "fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", + ); + self.push_str("write!(f, \"{:?}\", self)\n"); + self.push_str("}\n"); + self.push_str("}\n"); + self.push_str("impl std::error::Error for "); + self.push_str(&name); + self.push_str("{}\n"); + } + } + } + + fn type_tuple(&mut self, id: TypeId, _name: &str, tuple: &Tuple, docs: &Docs) { + let info = self.info(id); + for (name, mode) in self.modes_of(id) { + let lt = self.lifetime_for(&info, mode); + self.rustdoc(docs); + self.push_str(&format!("pub type {}", name)); + self.print_generics(lt); + self.push_str(" = ("); + for ty in tuple.types.iter() { + self.print_ty(ty, mode); + self.push_str(","); + } + self.push_str(");\n"); + } + } + + fn type_flags(&mut self, _id: TypeId, name: &str, flags: &Flags, docs: &Docs) { + self.rustdoc(docs); + self.src.push_str("wasmtime::component::flags!(\n"); + self.src + .push_str(&format!("{} {{\n", name.to_upper_camel_case())); + for flag in flags.flags.iter() { + // TODO wasmtime-component-macro doesnt support docs for flags rn + uwrite!( + self.src, + "#[component(name=\"{}\")] const {};\n", + flag.name, + flag.name.to_shouty_snake_case() + ); + } + self.src.push_str("}\n"); + self.src.push_str(");\n\n"); + } + + fn type_variant(&mut self, id: TypeId, _name: &str, variant: &Variant, docs: &Docs) { + self.print_rust_enum( + id, + variant.cases.iter().map(|c| { + ( + c.name.to_upper_camel_case(), + Some(c.name.clone()), + &c.docs, + c.ty.as_ref(), + ) + }), + docs, + "variant", + ); + } + + fn type_union(&mut self, id: TypeId, _name: &str, union: &Union, docs: &Docs) { + self.print_rust_enum( + id, + std::iter::zip(self.union_case_names(union), &union.cases) + .map(|(name, case)| (name, None, &case.docs, Some(&case.ty))), + docs, + "union", + ); + } + + fn type_option(&mut self, id: TypeId, _name: &str, payload: &Type, docs: &Docs) { + let info = self.info(id); + + for (name, mode) in self.modes_of(id) { + self.rustdoc(docs); + let lt = self.lifetime_for(&info, mode); + self.push_str(&format!("pub type {}", name)); + self.print_generics(lt); + self.push_str("= Option<"); + self.print_ty(payload, mode); + self.push_str(">;\n"); + } + } + + fn print_rust_enum<'b>( + &mut self, + id: TypeId, + cases: impl IntoIterator, &'b Docs, Option<&'b Type>)> + Clone, + docs: &Docs, + derive_component: &str, + ) where + Self: Sized, + { + let info = self.info(id); + + for (name, mode) in self.modes_of(id) { + let name = name.to_upper_camel_case(); + + self.rustdoc(docs); + let lt = self.lifetime_for(&info, mode); + self.push_str("#[derive(wasmtime::component::ComponentType)]\n"); + if lt.is_none() { + self.push_str("#[derive(wasmtime::component::Lift)]\n"); + } + self.push_str("#[derive(wasmtime::component::Lower)]\n"); + self.push_str(&format!("#[component({})]\n", derive_component)); + if !info.has_list { + self.push_str("#[derive(Clone, Copy)]\n"); + } else { + self.push_str("#[derive(Clone)]\n"); + } + self.push_str(&format!("pub enum {name}")); + self.print_generics(lt); + self.push_str("{\n"); + for (case_name, component_name, docs, payload) in cases.clone() { + self.rustdoc(docs); + if let Some(n) = component_name { + self.push_str(&format!("#[component(name = \"{}\")] ", n)); + } + self.push_str(&case_name); + if let Some(ty) = payload { + self.push_str("("); + self.print_ty(ty, mode); + self.push_str(")") + } + self.push_str(",\n"); + } + self.push_str("}\n"); + + self.print_rust_enum_debug( + id, + mode, + &name, + cases + .clone() + .into_iter() + .map(|(name, _attr, _docs, ty)| (name, ty)), + ); + + if info.error { + self.push_str("impl"); + self.print_generics(lt); + self.push_str(" core::fmt::Display for "); + self.push_str(&name); + self.print_generics(lt); + self.push_str(" {\n"); + self.push_str( + "fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", + ); + self.push_str("write!(f, \"{:?}\", self)"); + self.push_str("}\n"); + self.push_str("}\n"); + self.push_str("\n"); + + self.push_str("impl"); + self.print_generics(lt); + self.push_str(" std::error::Error for "); + self.push_str(&name); + self.print_generics(lt); + self.push_str(" {}\n"); + } + } + } + + fn print_rust_enum_debug<'b>( + &mut self, + id: TypeId, + mode: TypeMode, + name: &str, + cases: impl IntoIterator)>, + ) where + Self: Sized, + { + let info = self.info(id); + let lt = self.lifetime_for(&info, mode); + self.push_str("impl"); + self.print_generics(lt); + self.push_str(" core::fmt::Debug for "); + self.push_str(name); + self.print_generics(lt); + self.push_str(" {\n"); + self.push_str("fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n"); + self.push_str("match self {\n"); + for (case_name, payload) in cases { + self.push_str(name); + self.push_str("::"); + self.push_str(&case_name); + if payload.is_some() { + self.push_str("(e)"); + } + self.push_str(" => {\n"); + self.push_str(&format!("f.debug_tuple(\"{}::{}\")", name, case_name)); + if payload.is_some() { + self.push_str(".field(e)"); + } + self.push_str(".finish()\n"); + self.push_str("}\n"); + } + self.push_str("}\n"); + self.push_str("}\n"); + self.push_str("}\n"); + } + + fn type_result(&mut self, id: TypeId, _name: &str, result: &Result_, docs: &Docs) { + let info = self.info(id); + + for (name, mode) in self.modes_of(id) { + self.rustdoc(docs); + let lt = self.lifetime_for(&info, mode); + self.push_str(&format!("pub type {}", name)); + self.print_generics(lt); + self.push_str("= Result<"); + self.print_optional_ty(result.ok.as_ref(), mode); + self.push_str(","); + self.print_optional_ty(result.err.as_ref(), mode); + self.push_str(">;\n"); + } + } + + fn type_enum(&mut self, id: TypeId, name: &str, enum_: &Enum, docs: &Docs) { + let info = self.info(id); + + let name = name.to_upper_camel_case(); + self.rustdoc(docs); + self.push_str("#[derive(wasmtime::component::ComponentType)]\n"); + self.push_str("#[derive(wasmtime::component::Lift)]\n"); + self.push_str("#[derive(wasmtime::component::Lower)]\n"); + self.push_str("#[component(enum)]\n"); + self.push_str("#[derive(Clone, Copy, PartialEq, Eq)]\n"); + self.push_str(&format!("pub enum {} {{\n", name.to_upper_camel_case())); + for case in enum_.cases.iter() { + self.rustdoc(&case.docs); + self.push_str(&format!("#[component(name = \"{}\")]", case.name)); + self.push_str(&case.name.to_upper_camel_case()); + self.push_str(",\n"); + } + self.push_str("}\n"); + + // Auto-synthesize an implementation of the standard `Error` trait for + // error-looking types based on their name. + if info.error { + self.push_str("impl "); + self.push_str(&name); + self.push_str("{\n"); + + self.push_str("pub fn name(&self) -> &'static str {\n"); + self.push_str("match self {\n"); + for case in enum_.cases.iter() { + self.push_str(&name); + self.push_str("::"); + self.push_str(&case.name.to_upper_camel_case()); + self.push_str(" => \""); + self.push_str(case.name.as_str()); + self.push_str("\",\n"); + } + self.push_str("}\n"); + self.push_str("}\n"); + + self.push_str("pub fn message(&self) -> &'static str {\n"); + self.push_str("match self {\n"); + for case in enum_.cases.iter() { + self.push_str(&name); + self.push_str("::"); + self.push_str(&case.name.to_upper_camel_case()); + self.push_str(" => \""); + if let Some(contents) = &case.docs.contents { + self.push_str(contents.trim()); + } + self.push_str("\",\n"); + } + self.push_str("}\n"); + self.push_str("}\n"); + + self.push_str("}\n"); + + self.push_str("impl core::fmt::Debug for "); + self.push_str(&name); + self.push_str( + "{\nfn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", + ); + self.push_str("f.debug_struct(\""); + self.push_str(&name); + self.push_str("\")\n"); + self.push_str(".field(\"code\", &(*self as i32))\n"); + self.push_str(".field(\"name\", &self.name())\n"); + self.push_str(".field(\"message\", &self.message())\n"); + self.push_str(".finish()\n"); + self.push_str("}\n"); + self.push_str("}\n"); + + self.push_str("impl core::fmt::Display for "); + self.push_str(&name); + self.push_str( + "{\nfn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", + ); + self.push_str("write!(f, \"{} (error {})\", self.name(), *self as i32)"); + self.push_str("}\n"); + self.push_str("}\n"); + self.push_str("\n"); + self.push_str("impl std::error::Error for "); + self.push_str(&name); + self.push_str("{}\n"); + } else { + self.print_rust_enum_debug( + id, + TypeMode::Owned, + &name, + enum_ + .cases + .iter() + .map(|c| (c.name.to_upper_camel_case(), None)), + ) + } + } + + fn type_alias(&mut self, id: TypeId, _name: &str, ty: &Type, docs: &Docs) { + let info = self.info(id); + for (name, mode) in self.modes_of(id) { + self.rustdoc(docs); + self.push_str(&format!("pub type {}", name)); + let lt = self.lifetime_for(&info, mode); + self.print_generics(lt); + self.push_str(" = "); + self.print_ty(ty, mode); + self.push_str(";\n"); + } + } + + fn type_list(&mut self, id: TypeId, _name: &str, ty: &Type, docs: &Docs) { + let info = self.info(id); + for (name, mode) in self.modes_of(id) { + let lt = self.lifetime_for(&info, mode); + self.rustdoc(docs); + self.push_str(&format!("pub type {}", name)); + self.print_generics(lt); + self.push_str(" = "); + self.print_list(ty, mode); + self.push_str(";\n"); + } + } + + fn print_result_ty(&mut self, results: &Results, mode: TypeMode) { + match results { + Results::Named(rs) => match rs.len() { + 0 => self.push_str("()"), + 1 => self.print_ty(&rs[0].1, mode), + _ => { + self.push_str("("); + for (i, (_, ty)) in rs.iter().enumerate() { + if i > 0 { + self.push_str(", ") + } + self.print_ty(ty, mode) + } + self.push_str(")"); + } + }, + Results::Anon(ty) => self.print_ty(ty, mode), + } + } + + fn special_case_host_error(&self, results: &Results) -> Option<&Result_> { + // We only support the wit_bindgen_host_wasmtime_rust::Error case when + // a function has just one result, which is itself a `result`, and the + // `e` is *not* a primitive (i.e. defined in std) type. + let mut i = results.iter_types(); + if i.len() == 1 { + match i.next().unwrap() { + Type::Id(id) => match &self.iface.types[*id].kind { + TypeDefKind::Result(r) => match r.err { + Some(Type::Id(_)) => Some(&r), + _ => None, + }, + _ => None, + }, + _ => None, + } + } else { + None + } + } + + fn generate_add_to_linker(&mut self, name: &str) { + let camel = name.to_upper_camel_case(); + + if self.gen.opts.async_ { + uwriteln!(self.src, "#[wit_bindgen_host_wasmtime_rust::async_trait]") + } + // Generate the `pub trait` which represents the host functionality for + // this import. + uwriteln!(self.src, "pub trait {camel}: Sized {{"); + for func in self.iface.functions.iter() { + self.rustdoc(&func.docs); + + if self.gen.opts.async_ { + self.push_str("async "); + } + self.push_str("fn "); + self.push_str(&to_rust_ident(&func.name)); + self.push_str("(&mut self, "); + for (name, param) in func.params.iter() { + let name = to_rust_ident(name); + self.push_str(&name); + self.push_str(": "); + self.print_ty(param, TypeMode::Owned); + self.push_str(","); + } + self.push_str(")"); + self.push_str(" -> "); + + if let Some(r) = self.special_case_host_error(&func.results).cloned() { + // Functions which have a single result `result` get special + // cased to use the host_wasmtime_rust::Error, making it possible + // for them to trap or use `?` to propogate their errors + self.push_str("wit_bindgen_host_wasmtime_rust::Result<"); + if let Some(ok) = r.ok { + self.print_ty(&ok, TypeMode::Owned); + } else { + self.push_str("()"); + } + self.push_str(","); + if let Some(err) = r.err { + self.print_ty(&err, TypeMode::Owned); + } else { + self.push_str("()"); + } + self.push_str(">"); + } else { + // All other functions get their return values wrapped in an anyhow::Result. + // Returning the anyhow::Error case can be used to trap. + self.push_str("anyhow::Result<"); + self.print_result_ty(&func.results, TypeMode::Owned); + self.push_str(">"); + } + + self.push_str(";\n"); + } + uwriteln!(self.src, "}}"); + + let where_clause = if self.gen.opts.async_ { + format!("T: Send, U: {camel} + Send") + } else { + format!("U: {camel}") + }; + uwriteln!( + self.src, + " + pub fn add_to_linker( + linker: &mut wasmtime::component::Linker, + get: impl Fn(&mut T) -> &mut U + Send + Sync + Copy + 'static, + ) -> anyhow::Result<()> + where {where_clause}, + {{ + " + ); + uwriteln!(self.src, "let mut inst = linker.instance(\"{name}\")?;"); + for func in self.iface.functions.iter() { + uwrite!( + self.src, + "inst.{}(\"{}\", ", + if self.gen.opts.async_ { + "func_wrap_async" + } else { + "func_wrap" + }, + func.name + ); + self.generate_guest_import_closure(func); + uwriteln!(self.src, ")?;") + } + uwriteln!(self.src, "Ok(())"); + uwriteln!(self.src, "}}"); + } + + fn generate_guest_import_closure(&mut self, func: &Function) { + // Generate the closure that's passed to a `Linker`, the final piece of + // codegen here. + self.src + .push_str("move |mut caller: wasmtime::StoreContextMut<'_, T>, ("); + for (i, _param) in func.params.iter().enumerate() { + uwrite!(self.src, "arg{},", i); + } + self.src.push_str(") : ("); + for param in func.params.iter() { + // Lift is required to be impled for this type, so we can't use + // a borrowed type: + self.print_ty(¶m.1, TypeMode::Owned); + self.src.push_str(", "); + } + self.src.push_str(") |"); + if self.gen.opts.async_ { + self.src.push_str(" Box::new(async move { \n"); + } else { + self.src.push_str(" { \n"); + } + + if self.gen.opts.tracing { + self.src.push_str(&format!( + " + let span = wit_bindgen_host_wasmtime_rust::tracing::span!( + wit_bindgen_host_wasmtime_rust::tracing::Level::TRACE, + \"wit-bindgen guest import\", + module = \"{}\", + function = \"{}\", + ); + let _enter = span.enter(); + ", + self.iface.name, func.name, + )); + } + + self.src.push_str("let host = get(caller.data_mut());\n"); + + uwrite!(self.src, "let r = host.{}(", func.name.to_snake_case()); + for (i, _) in func.params.iter().enumerate() { + uwrite!(self.src, "arg{},", i); + } + if self.gen.opts.async_ { + uwrite!(self.src, ").await;\n"); + } else { + uwrite!(self.src, ");\n"); + } + + if self.special_case_host_error(&func.results).is_some() { + uwrite!( + self.src, + "match r {{ + Ok(a) => Ok((Ok(a),)), + Err(e) => match e.downcast() {{ + Ok(api_error) => Ok((Err(api_error),)), + Err(anyhow_error) => Err(anyhow_error), + }} + }}" + ); + } else if func.results.iter_types().len() == 1 { + uwrite!(self.src, "Ok((r?,))\n"); + } else { + uwrite!(self.src, "r\n"); + } + + if self.gen.opts.async_ { + // Need to close Box::new and async block + self.src.push_str("})"); + } else { + self.src.push_str("}"); + } + } + + fn extract_typed_functions(&mut self) -> Vec<(String, String)> { + let prev = mem::take(&mut self.src); + let mut ret = Vec::new(); + for func in self.iface.functions.iter() { + let snake = func.name.to_snake_case(); + uwrite!(self.src, "*__exports.typed_func::<("); + for (_, ty) in func.params.iter() { + self.print_ty(ty, TypeMode::AllBorrowed("'_")); + self.push_str(", "); + } + self.src.push_str("), ("); + for ty in func.results.iter_types() { + self.print_ty(ty, TypeMode::Owned); + self.push_str(", "); + } + self.src.push_str(")>(\""); + self.src.push_str(&func.name); + self.src.push_str("\")?.func()"); + + ret.push((snake, mem::take(&mut self.src).to_string())); + } + self.src = prev; + return ret; + } + + fn define_rust_guest_export(&mut self, ns: Option<&str>, func: &Function) { + let (async_, async__, await_) = if self.gen.opts.async_ { + ("async", "_async", ".await") + } else { + ("", "", "") + }; + + self.rustdoc(&func.docs); + uwrite!( + self.src, + "pub {async_} fn {}(&self, mut store: S, ", + func.name.to_snake_case(), + ); + for (i, param) in func.params.iter().enumerate() { + uwrite!(self.src, "arg{}: ", i); + self.print_ty(¶m.1, TypeMode::AllBorrowed("'_")); + self.push_str(","); + } + self.src.push_str(") -> anyhow::Result<"); + self.print_result_ty(&func.results, TypeMode::Owned); + + if self.gen.opts.async_ { + self.src + .push_str("> where ::Data: Send {\n"); + } else { + self.src.push_str("> {\n"); + } + + if self.gen.opts.tracing { + self.src.push_str(&format!( + " + let span = wit_bindgen_host_wasmtime_rust::tracing::span!( + wit_bindgen_host_wasmtime_rust::tracing::Level::TRACE, + \"wit-bindgen guest export\", + module = \"{}\", + function = \"{}\", + ); + let _enter = span.enter(); + ", + ns.unwrap_or("default"), + func.name, + )); + } + + self.src.push_str("let callee = unsafe {\n"); + self.src.push_str("wasmtime::component::TypedFunc::<("); + for (_, ty) in func.params.iter() { + self.print_ty(ty, TypeMode::AllBorrowed("'_")); + self.push_str(", "); + } + self.src.push_str("), ("); + for ty in func.results.iter_types() { + self.print_ty(ty, TypeMode::Owned); + self.push_str(", "); + } + uwriteln!( + self.src, + ")>::new_unchecked(self.{})", + func.name.to_snake_case() + ); + self.src.push_str("};\n"); + self.src.push_str("let ("); + for (i, _) in func.results.iter_types().enumerate() { + uwrite!(self.src, "ret{},", i); + } + uwrite!( + self.src, + ") = callee.call{async__}(store.as_context_mut(), (" + ); + for (i, _) in func.params.iter().enumerate() { + uwrite!(self.src, "arg{}, ", i); + } + uwriteln!(self.src, ")){await_}?;"); + + uwriteln!( + self.src, + "callee.post_return{async__}(store.as_context_mut()){await_}?;" + ); + + self.src.push_str("Ok("); + if func.results.iter_types().len() == 1 { + self.src.push_str("ret0"); + } else { + self.src.push_str("("); + for (i, _) in func.results.iter_types().enumerate() { + uwrite!(self.src, "ret{},", i); + } + self.src.push_str(")"); + } + self.src.push_str(")\n"); + + // End function body + self.src.push_str("}\n"); + } + + fn generate_from_error_impls(&mut self) { + for (id, ty) in self.iface.types.iter() { + if ty.name.is_none() { + continue; + } + let info = self.info(id); + if info.error { + for (name, mode) in self.modes_of(id) { + let name = name.to_upper_camel_case(); + if self.lifetime_for(&info, mode).is_some() { + continue; + } + self.push_str("impl From<"); + self.push_str(&name); + self.push_str("> for wit_bindgen_host_wasmtime_rust::Error<"); + self.push_str(&name); + self.push_str("> {\n"); + self.push_str("fn from(e: "); + self.push_str(&name); + self.push_str(") -> wit_bindgen_host_wasmtime_rust::Error::< "); + self.push_str(&name); + self.push_str("> {\n"); + self.push_str("wit_bindgen_host_wasmtime_rust::Error::new(e)\n"); + self.push_str("}\n"); + self.push_str("}\n"); + } + } + } + } + + fn rustdoc(&mut self, docs: &Docs) { + let docs = match &docs.contents { + Some(docs) => docs, + None => return, + }; + for line in docs.trim().lines() { + self.push_str("/// "); + self.push_str(line); + self.push_str("\n"); + } + } +} + +impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> { + fn iface(&self) -> &'a Interface { + self.iface + } + + fn default_param_mode(&self) -> TypeMode { + self.default_param_mode + } + + fn push_str(&mut self, s: &str) { + self.src.push_str(s); + } + + fn info(&self, ty: TypeId) -> TypeInfo { + self.types.get(ty) + } +} diff --git a/crates/wit-bindgen/src/rust.rs b/crates/wit-bindgen/src/rust.rs new file mode 100644 index 000000000000..d791478356ae --- /dev/null +++ b/crates/wit-bindgen/src/rust.rs @@ -0,0 +1,412 @@ +use crate::types::TypeInfo; +use heck::*; +use std::collections::HashMap; +use std::fmt::Write; +use wit_parser::*; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum TypeMode { + Owned, + AllBorrowed(&'static str), +} + +pub trait RustGenerator<'a> { + fn iface(&self) -> &'a Interface; + + fn push_str(&mut self, s: &str); + fn info(&self, ty: TypeId) -> TypeInfo; + fn default_param_mode(&self) -> TypeMode; + + fn print_ty(&mut self, ty: &Type, mode: TypeMode) { + match ty { + Type::Id(t) => self.print_tyid(*t, mode), + Type::Bool => self.push_str("bool"), + Type::U8 => self.push_str("u8"), + Type::U16 => self.push_str("u16"), + Type::U32 => self.push_str("u32"), + Type::U64 => self.push_str("u64"), + Type::S8 => self.push_str("i8"), + Type::S16 => self.push_str("i16"), + Type::S32 => self.push_str("i32"), + Type::S64 => self.push_str("i64"), + Type::Float32 => self.push_str("f32"), + Type::Float64 => self.push_str("f64"), + Type::Char => self.push_str("char"), + Type::String => match mode { + TypeMode::AllBorrowed(lt) => { + self.push_str("&"); + if lt != "'_" { + self.push_str(lt); + self.push_str(" "); + } + self.push_str("str"); + } + TypeMode::Owned => self.push_str("String"), + }, + } + } + + fn print_optional_ty(&mut self, ty: Option<&Type>, mode: TypeMode) { + match ty { + Some(ty) => self.print_ty(ty, mode), + None => self.push_str("()"), + } + } + + fn print_tyid(&mut self, id: TypeId, mode: TypeMode) { + let info = self.info(id); + let lt = self.lifetime_for(&info, mode); + let ty = &self.iface().types[id]; + if ty.name.is_some() { + let name = if lt.is_some() { + self.param_name(id) + } else { + self.result_name(id) + }; + self.push_str(&name); + + // If the type recursively owns data and it's a + // variant/record/list, then we need to place the + // lifetime parameter on the type as well. + if info.has_list && needs_generics(self.iface(), &ty.kind) { + self.print_generics(lt); + } + + return; + + fn needs_generics(iface: &Interface, ty: &TypeDefKind) -> bool { + match ty { + TypeDefKind::Variant(_) + | TypeDefKind::Record(_) + | TypeDefKind::Option(_) + | TypeDefKind::Result(_) + | TypeDefKind::Future(_) + | TypeDefKind::Stream(_) + | TypeDefKind::List(_) + | TypeDefKind::Flags(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Tuple(_) + | TypeDefKind::Union(_) => true, + TypeDefKind::Type(Type::Id(t)) => needs_generics(iface, &iface.types[*t].kind), + TypeDefKind::Type(Type::String) => true, + TypeDefKind::Type(_) => false, + } + } + } + + match &ty.kind { + TypeDefKind::List(t) => self.print_list(t, mode), + + TypeDefKind::Option(t) => { + self.push_str("Option<"); + self.print_ty(t, mode); + self.push_str(">"); + } + + TypeDefKind::Result(r) => { + self.push_str("Result<"); + self.print_optional_ty(r.ok.as_ref(), mode); + self.push_str(","); + self.print_optional_ty(r.err.as_ref(), mode); + self.push_str(">"); + } + + TypeDefKind::Variant(_) => panic!("unsupported anonymous variant"), + + // Tuple-like records are mapped directly to Rust tuples of + // types. Note the trailing comma after each member to + // appropriately handle 1-tuples. + TypeDefKind::Tuple(t) => { + self.push_str("("); + for ty in t.types.iter() { + self.print_ty(ty, mode); + self.push_str(","); + } + self.push_str(")"); + } + TypeDefKind::Record(_) => { + panic!("unsupported anonymous type reference: record") + } + TypeDefKind::Flags(_) => { + panic!("unsupported anonymous type reference: flags") + } + TypeDefKind::Enum(_) => { + panic!("unsupported anonymous type reference: enum") + } + TypeDefKind::Union(_) => { + panic!("unsupported anonymous type reference: union") + } + TypeDefKind::Future(ty) => { + self.push_str("Future<"); + self.print_optional_ty(ty.as_ref(), mode); + self.push_str(">"); + } + TypeDefKind::Stream(stream) => { + self.push_str("Stream<"); + self.print_optional_ty(stream.element.as_ref(), mode); + self.push_str(","); + self.print_optional_ty(stream.end.as_ref(), mode); + self.push_str(">"); + } + + TypeDefKind::Type(t) => self.print_ty(t, mode), + } + } + + fn print_list(&mut self, ty: &Type, mode: TypeMode) { + match mode { + TypeMode::AllBorrowed(lt) => { + self.push_str("&"); + if lt != "'_" { + self.push_str(lt); + self.push_str(" "); + } + self.push_str("["); + self.print_ty(ty, mode); + self.push_str("]"); + } + TypeMode::Owned => { + self.push_str("Vec<"); + self.print_ty(ty, mode); + self.push_str(">"); + } + } + } + + fn print_generics(&mut self, lifetime: Option<&str>) { + if lifetime.is_none() { + return; + } + self.push_str("<"); + if let Some(lt) = lifetime { + self.push_str(lt); + self.push_str(","); + } + self.push_str(">"); + } + + fn modes_of(&self, ty: TypeId) -> Vec<(String, TypeMode)> { + let info = self.info(ty); + let mut result = Vec::new(); + if info.param { + result.push((self.param_name(ty), self.default_param_mode())); + } + if info.result && (!info.param || self.uses_two_names(&info)) { + result.push((self.result_name(ty), TypeMode::Owned)); + } + return result; + } + + /// Writes the camel-cased 'name' of the passed type to `out`, as used to name union variants. + fn write_name(&self, ty: &Type, out: &mut String) { + match ty { + Type::Bool => out.push_str("Bool"), + Type::U8 => out.push_str("U8"), + Type::U16 => out.push_str("U16"), + Type::U32 => out.push_str("U32"), + Type::U64 => out.push_str("U64"), + Type::S8 => out.push_str("I8"), + Type::S16 => out.push_str("I16"), + Type::S32 => out.push_str("I32"), + Type::S64 => out.push_str("I64"), + Type::Float32 => out.push_str("F32"), + Type::Float64 => out.push_str("F64"), + Type::Char => out.push_str("Char"), + Type::String => out.push_str("String"), + Type::Id(id) => { + let ty = &self.iface().types[*id]; + match &ty.name { + Some(name) => out.push_str(&name.to_upper_camel_case()), + None => match &ty.kind { + TypeDefKind::Option(ty) => { + out.push_str("Optional"); + self.write_name(ty, out); + } + TypeDefKind::Result(_) => out.push_str("Result"), + TypeDefKind::Tuple(_) => out.push_str("Tuple"), + TypeDefKind::List(ty) => { + self.write_name(ty, out); + out.push_str("List") + } + TypeDefKind::Future(ty) => { + self.write_optional_name(ty.as_ref(), out); + out.push_str("Future"); + } + TypeDefKind::Stream(s) => { + self.write_optional_name(s.element.as_ref(), out); + self.write_optional_name(s.end.as_ref(), out); + out.push_str("Stream"); + } + + TypeDefKind::Type(ty) => self.write_name(ty, out), + TypeDefKind::Record(_) => out.push_str("Record"), + TypeDefKind::Flags(_) => out.push_str("Flags"), + TypeDefKind::Variant(_) => out.push_str("Variant"), + TypeDefKind::Enum(_) => out.push_str("Enum"), + TypeDefKind::Union(_) => out.push_str("Union"), + }, + } + } + } + } + + fn write_optional_name(&self, ty: Option<&Type>, out: &mut String) { + match ty { + Some(ty) => self.write_name(ty, out), + None => out.push_str("()"), + } + } + + /// Returns the names for the cases of the passed union. + fn union_case_names(&self, union: &Union) -> Vec { + enum UsedState<'a> { + /// This name has been used once before. + /// + /// Contains a reference to the name given to the first usage so that a suffix can be added to it. + Once(&'a mut String), + /// This name has already been used multiple times. + /// + /// Contains the number of times this has already been used. + Multiple(usize), + } + + // A `Vec` of the names we're assigning each of the union's cases in order. + let mut case_names = vec![String::new(); union.cases.len()]; + // A map from case names to their `UsedState`. + let mut used = HashMap::new(); + for (case, name) in union.cases.iter().zip(case_names.iter_mut()) { + self.write_name(&case.ty, name); + + match used.get_mut(name.as_str()) { + None => { + // Initialise this name's `UsedState`, with a mutable reference to this name + // in case we have to add a suffix to it later. + used.insert(name.clone(), UsedState::Once(name)); + // Since this is the first (and potentially only) usage of this name, + // we don't need to add a suffix here. + } + Some(state) => match state { + UsedState::Multiple(n) => { + // Add a suffix of the index of this usage. + write!(name, "{n}").unwrap(); + // Add one to the number of times this type has been used. + *n += 1; + } + UsedState::Once(first) => { + // Add a suffix of 0 to the first usage. + first.push('0'); + // We now get a suffix of 1. + name.push('1'); + // Then update the state. + *state = UsedState::Multiple(2); + } + }, + } + } + + case_names + } + + fn param_name(&self, ty: TypeId) -> String { + let info = self.info(ty); + let name = self.iface().types[ty] + .name + .as_ref() + .unwrap() + .to_upper_camel_case(); + if self.uses_two_names(&info) { + format!("{}Param", name) + } else { + name + } + } + + fn result_name(&self, ty: TypeId) -> String { + let info = self.info(ty); + let name = self.iface().types[ty] + .name + .as_ref() + .unwrap() + .to_upper_camel_case(); + if self.uses_two_names(&info) { + format!("{}Result", name) + } else { + name + } + } + + fn uses_two_names(&self, info: &TypeInfo) -> bool { + info.has_list + && info.param + && info.result + && match self.default_param_mode() { + TypeMode::AllBorrowed(_) => true, + TypeMode::Owned => false, + } + } + + fn lifetime_for(&self, info: &TypeInfo, mode: TypeMode) -> Option<&'static str> { + match mode { + TypeMode::AllBorrowed(s) if info.has_list => Some(s), + _ => None, + } + } +} + +pub fn to_rust_ident(name: &str) -> String { + match name { + // Escape Rust keywords. + // Source: https://doc.rust-lang.org/reference/keywords.html + "as" => "as_".into(), + "break" => "break_".into(), + "const" => "const_".into(), + "continue" => "continue_".into(), + "crate" => "crate_".into(), + "else" => "else_".into(), + "enum" => "enum_".into(), + "extern" => "extern_".into(), + "false" => "false_".into(), + "fn" => "fn_".into(), + "for" => "for_".into(), + "if" => "if_".into(), + "impl" => "impl_".into(), + "in" => "in_".into(), + "let" => "let_".into(), + "loop" => "loop_".into(), + "match" => "match_".into(), + "mod" => "mod_".into(), + "move" => "move_".into(), + "mut" => "mut_".into(), + "pub" => "pub_".into(), + "ref" => "ref_".into(), + "return" => "return_".into(), + "self" => "self_".into(), + "static" => "static_".into(), + "struct" => "struct_".into(), + "super" => "super_".into(), + "trait" => "trait_".into(), + "true" => "true_".into(), + "type" => "type_".into(), + "unsafe" => "unsafe_".into(), + "use" => "use_".into(), + "where" => "where_".into(), + "while" => "while_".into(), + "async" => "async_".into(), + "await" => "await_".into(), + "dyn" => "dyn_".into(), + "abstract" => "abstract_".into(), + "become" => "become_".into(), + "box" => "box_".into(), + "do" => "do_".into(), + "final" => "final_".into(), + "macro" => "macro_".into(), + "override" => "override_".into(), + "priv" => "priv_".into(), + "typeof" => "typeof_".into(), + "unsized" => "unsized_".into(), + "virtual" => "virtual_".into(), + "yield" => "yield_".into(), + "try" => "try_".into(), + s => s.to_snake_case(), + } +} diff --git a/crates/wit-bindgen/src/source.rs b/crates/wit-bindgen/src/source.rs new file mode 100644 index 000000000000..f7099f49edf1 --- /dev/null +++ b/crates/wit-bindgen/src/source.rs @@ -0,0 +1,130 @@ +use std::fmt::{self, Write}; +use std::ops::Deref; + +/// Helper structure to maintain indentation automatically when printing. +#[derive(Default)] +pub struct Source { + s: String, + indent: usize, +} + +impl Source { + pub fn push_str(&mut self, src: &str) { + let lines = src.lines().collect::>(); + for (i, line) in lines.iter().enumerate() { + let trimmed = line.trim(); + if trimmed.starts_with('}') && self.s.ends_with(" ") { + self.s.pop(); + self.s.pop(); + } + self.s.push_str(if lines.len() == 1 { + line + } else { + line.trim_start() + }); + if trimmed.ends_with('{') { + self.indent += 1; + } + if trimmed.starts_with('}') { + // Note that a `saturating_sub` is used here to prevent a panic + // here in the case of invalid code being generated in debug + // mode. It's typically easier to debug those issues through + // looking at the source code rather than getting a panic. + self.indent = self.indent.saturating_sub(1); + } + if i != lines.len() - 1 || src.ends_with('\n') { + self.newline(); + } + } + } + + pub fn indent(&mut self, amt: usize) { + self.indent += amt; + } + + pub fn deindent(&mut self, amt: usize) { + self.indent -= amt; + } + + fn newline(&mut self) { + self.s.push('\n'); + for _ in 0..self.indent { + self.s.push_str(" "); + } + } + + pub fn as_mut_string(&mut self) -> &mut String { + &mut self.s + } +} + +impl Write for Source { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.push_str(s); + Ok(()) + } +} + +impl Deref for Source { + type Target = str; + fn deref(&self) -> &str { + &self.s + } +} + +impl From for String { + fn from(s: Source) -> String { + s.s + } +} + +#[cfg(test)] +mod tests { + use super::Source; + + #[test] + fn simple_append() { + let mut s = Source::default(); + s.push_str("x"); + assert_eq!(s.s, "x"); + s.push_str("y"); + assert_eq!(s.s, "xy"); + s.push_str("z "); + assert_eq!(s.s, "xyz "); + s.push_str(" a "); + assert_eq!(s.s, "xyz a "); + s.push_str("\na"); + assert_eq!(s.s, "xyz a \na"); + } + + #[test] + fn newline_remap() { + let mut s = Source::default(); + s.push_str("function() {\n"); + s.push_str("y\n"); + s.push_str("}\n"); + assert_eq!(s.s, "function() {\n y\n}\n"); + } + + #[test] + fn if_else() { + let mut s = Source::default(); + s.push_str("if() {\n"); + s.push_str("y\n"); + s.push_str("} else if () {\n"); + s.push_str("z\n"); + s.push_str("}\n"); + assert_eq!(s.s, "if() {\n y\n} else if () {\n z\n}\n"); + } + + #[test] + fn trim_ws() { + let mut s = Source::default(); + s.push_str( + "function() { + x + }", + ); + assert_eq!(s.s, "function() {\n x\n}"); + } +} diff --git a/crates/wit-bindgen/src/types.rs b/crates/wit-bindgen/src/types.rs new file mode 100644 index 000000000000..3709987f3186 --- /dev/null +++ b/crates/wit-bindgen/src/types.rs @@ -0,0 +1,207 @@ +use std::collections::HashMap; +use wit_parser::*; + +#[derive(Default)] +pub struct Types { + type_info: HashMap, +} + +#[derive(Default, Clone, Copy, Debug, PartialEq)] +pub struct TypeInfo { + /// Whether or not this type is ever used (transitively) within the + /// parameter of a function. + pub param: bool, + + /// Whether or not this type is ever used (transitively) within the + /// result of a function. + pub result: bool, + + /// Whether or not this type is ever used (transitively) within the + /// error case in the result of a function. + pub error: bool, + + /// Whether or not this type (transitively) has a list. + pub has_list: bool, +} + +impl std::ops::BitOrAssign for TypeInfo { + fn bitor_assign(&mut self, rhs: Self) { + self.param |= rhs.param; + self.result |= rhs.result; + self.error |= rhs.error; + self.has_list |= rhs.has_list; + } +} + +impl Types { + pub fn analyze(&mut self, iface: &Interface) { + for (t, _) in iface.types.iter() { + self.type_id_info(iface, t); + } + for f in iface.functions.iter() { + for (_, ty) in f.params.iter() { + self.set_param_result_ty( + iface, + ty, + TypeInfo { + param: true, + ..TypeInfo::default() + }, + ); + } + for ty in f.results.iter_types() { + self.set_param_result_ty( + iface, + ty, + TypeInfo { + result: true, + ..TypeInfo::default() + }, + ); + } + } + } + + pub fn get(&self, id: TypeId) -> TypeInfo { + self.type_info[&id] + } + + fn type_id_info(&mut self, iface: &Interface, ty: TypeId) -> TypeInfo { + if let Some(info) = self.type_info.get(&ty) { + return *info; + } + let mut info = TypeInfo::default(); + match &iface.types[ty].kind { + TypeDefKind::Record(r) => { + for field in r.fields.iter() { + info |= self.type_info(iface, &field.ty); + } + } + TypeDefKind::Tuple(t) => { + for ty in t.types.iter() { + info |= self.type_info(iface, ty); + } + } + TypeDefKind::Flags(_) => {} + TypeDefKind::Enum(_) => {} + TypeDefKind::Variant(v) => { + for case in v.cases.iter() { + info |= self.optional_type_info(iface, case.ty.as_ref()); + } + } + TypeDefKind::List(ty) => { + info = self.type_info(iface, ty); + info.has_list = true; + } + TypeDefKind::Type(ty) => { + info = self.type_info(iface, ty); + } + TypeDefKind::Option(ty) => { + info = self.type_info(iface, ty); + } + TypeDefKind::Result(r) => { + info = self.optional_type_info(iface, r.ok.as_ref()); + info |= self.optional_type_info(iface, r.err.as_ref()); + } + TypeDefKind::Union(u) => { + for case in u.cases.iter() { + info |= self.type_info(iface, &case.ty); + } + } + TypeDefKind::Future(ty) => { + info = self.optional_type_info(iface, ty.as_ref()); + } + TypeDefKind::Stream(stream) => { + info = self.optional_type_info(iface, stream.element.as_ref()); + info |= self.optional_type_info(iface, stream.end.as_ref()); + } + } + self.type_info.insert(ty, info); + info + } + + fn type_info(&mut self, iface: &Interface, ty: &Type) -> TypeInfo { + let mut info = TypeInfo::default(); + match ty { + Type::String => info.has_list = true, + Type::Id(id) => return self.type_id_info(iface, *id), + _ => {} + } + info + } + + fn optional_type_info(&mut self, iface: &Interface, ty: Option<&Type>) -> TypeInfo { + match ty { + Some(ty) => self.type_info(iface, ty), + None => TypeInfo::default(), + } + } + + fn set_param_result_id(&mut self, iface: &Interface, ty: TypeId, info: TypeInfo) { + match &iface.types[ty].kind { + TypeDefKind::Record(r) => { + for field in r.fields.iter() { + self.set_param_result_ty(iface, &field.ty, info) + } + } + TypeDefKind::Tuple(t) => { + for ty in t.types.iter() { + self.set_param_result_ty(iface, ty, info) + } + } + TypeDefKind::Flags(_) => {} + TypeDefKind::Enum(_) => {} + TypeDefKind::Variant(v) => { + for case in v.cases.iter() { + self.set_param_result_optional_ty(iface, case.ty.as_ref(), info) + } + } + TypeDefKind::List(ty) | TypeDefKind::Type(ty) | TypeDefKind::Option(ty) => { + self.set_param_result_ty(iface, ty, info) + } + TypeDefKind::Result(r) => { + self.set_param_result_optional_ty(iface, r.ok.as_ref(), info); + let mut info2 = info; + info2.error = info.result; + self.set_param_result_optional_ty(iface, r.err.as_ref(), info2); + } + TypeDefKind::Union(u) => { + for case in u.cases.iter() { + self.set_param_result_ty(iface, &case.ty, info) + } + } + TypeDefKind::Future(ty) => self.set_param_result_optional_ty(iface, ty.as_ref(), info), + TypeDefKind::Stream(stream) => { + self.set_param_result_optional_ty(iface, stream.element.as_ref(), info); + self.set_param_result_optional_ty(iface, stream.end.as_ref(), info); + } + } + } + + fn set_param_result_ty(&mut self, iface: &Interface, ty: &Type, info: TypeInfo) { + match ty { + Type::Id(id) => { + self.type_id_info(iface, *id); + let cur = self.type_info.get_mut(id).unwrap(); + let prev = *cur; + *cur |= info; + if prev != *cur { + self.set_param_result_id(iface, *id, info); + } + } + _ => {} + } + } + + fn set_param_result_optional_ty( + &mut self, + iface: &Interface, + ty: Option<&Type>, + info: TypeInfo, + ) { + match ty { + Some(ty) => self.set_param_result_ty(iface, ty, info), + None => (), + } + } +} diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index e7e799b31a19..8ff2f897b2f3 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -967,3 +967,9 @@ criteria = "safe-to-deploy" version = "0.34.0" notes = "I am the author of this crate." +[[audits.wit-parser]] +who = "Alex Crichton " +criteria = "safe-to-deploy" +version = "0.3.0" +notes = "The Bytecode Alliance is the author of this crate." + From 0fdac5bf98f4f7e974447d6097a6ee0484e69a41 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 22 Nov 2022 09:14:44 -0800 Subject: [PATCH 02/13] Refactor `wasmtime-component-macro` a bit Make room for a `wit_bindgen` macro to slot in. --- crates/component-macro/Cargo.toml | 1 + crates/component-macro/src/component.rs | 1191 ++++++++++++++++++++++ crates/component-macro/src/lib.rs | 1216 +---------------------- 3 files changed, 1209 insertions(+), 1199 deletions(-) create mode 100644 crates/component-macro/src/component.rs diff --git a/crates/component-macro/Cargo.toml b/crates/component-macro/Cargo.toml index 882f497092d9..47fb033c0747 100644 --- a/crates/component-macro/Cargo.toml +++ b/crates/component-macro/Cargo.toml @@ -18,6 +18,7 @@ proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["extra-traits"] } wasmtime-component-util = { workspace = true } +wasmtime-wit-bindgen = { workspace = true } [badges] maintenance = { status = "actively-developed" } diff --git a/crates/component-macro/src/component.rs b/crates/component-macro/src/component.rs new file mode 100644 index 000000000000..b87fb6cb3c42 --- /dev/null +++ b/crates/component-macro/src/component.rs @@ -0,0 +1,1191 @@ +use proc_macro2::{Literal, TokenStream, TokenTree}; +use quote::{format_ident, quote}; +use std::collections::HashSet; +use std::fmt; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{braced, parse_quote, Data, DeriveInput, Error, Result, Token}; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; + +#[derive(Debug, Copy, Clone)] +pub enum VariantStyle { + Variant, + Enum, + Union, +} + +impl fmt::Display for VariantStyle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::Variant => "variant", + Self::Enum => "enum", + Self::Union => "union", + }) + } +} + +#[derive(Debug, Copy, Clone)] +enum Style { + Record, + Variant(VariantStyle), +} + +fn find_style(input: &DeriveInput) -> Result