diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6db47984..79ab7bf6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -49,6 +49,19 @@ jobs: cargo test -p ort --verbose --features fetch-models -- --test-threads 1 # Test examples that use in-tree graphs (do NOT run any of the examples that download ~700 MB graphs from pyke parcel...) cargo run --example custom-ops + test-wasm: + name: Test WebAssembly + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install nightly Rust toolchain + uses: dtolnay/rust-toolchain@nightly + - name: Install wasm-pack + run: cargo install wasm-pack + - name: Run tests + working-directory: examples/webassembly + run: | + wasm-pack test --node # Disable cross-compile until cross updates aarch64-unknown-linux-gnu to Ubuntu 22.04 # ref https://github.com/cross-rs/cross/pull/973 #cross-compile: diff --git a/.gitignore b/.gitignore index a76dbd4f..bf1af90f 100644 --- a/.gitignore +++ b/.gitignore @@ -186,6 +186,7 @@ WixTools/ # ONNX Runtime downloaded models **/*.onnx **/*.ort +!examples/webassembly/**/*.ort !tests/data/*.onnx !tests/data/*.ort diff --git a/Cargo.toml b/Cargo.toml index 05690ad4..5939f274 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,8 @@ members = [ 'examples/gpt2', 'examples/model-info', 'examples/yolov8', - 'examples/modnet' + 'examples/modnet', + 'examples/webassembly' ] default-members = [ '.', @@ -45,6 +46,7 @@ codegen-units = 1 [package.metadata.docs.rs] features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] +targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"] rustdoc-args = [ "--cfg", "docsrs" ] [features] @@ -91,6 +93,10 @@ libc = { version = "0.2", optional = true } [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", optional = true, features = [ "std", "libloaderapi" ] } +[target.'cfg(target_family = "wasm")'.dependencies] +js-sys = "0.3" +web-sys = "0.3" + 
[dev-dependencies]
 anyhow = "1.0"
 ureq = "2.1"
@@ -100,6 +106,7 @@ tracing-subscriber = { version = "0.3", default-features = false, features = [ "
 glassbench = "0.4"
 tokio = { version = "1.36", features = [ "test-util" ] }
 tokio-test = "0.4.3"
+wasm-bindgen-test = "0.3"
 
 [[bench]]
 name = "squeezenet"
diff --git a/examples/webassembly/Cargo.toml b/examples/webassembly/Cargo.toml
new file mode 100644
index 00000000..4c235e54
--- /dev/null
+++ b/examples/webassembly/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+publish = false
+name = "example-webassembly"
+version = "0.0.0"
+edition = "2021"
+
+[lib]
+name = "ortwasm"
+crate-type = ["cdylib"]
+
+[dependencies]
+ort = { path = "../../" }
+ndarray = "0.15"
+wasm-bindgen = "0.2.92"
+web-sys = { version = "0.3", features = [ "console" ] }
+tracing = "0.1"
+tracing-subscriber = "0.3"
+tracing-subscriber-wasm = "0.1"
+
+[dev-dependencies]
+wasm-bindgen-test = "0.3"
+console_error_panic_hook = "0.1"
+
+[features]
+load-dynamic = [ "ort/load-dynamic" ]
diff --git a/examples/webassembly/src/lib.rs b/examples/webassembly/src/lib.rs
new file mode 100644
index 00000000..589bc11d
--- /dev/null
+++ b/examples/webassembly/src/lib.rs
@@ -0,0 +1,66 @@
+use ndarray::{Array4, ArrayViewD};
+use ort::Session;
+use wasm_bindgen::prelude::*;
+
+static MODEL_BYTES: &[u8] = include_bytes!("upsample.ort");
+
+/// Runs the bundled upsampling model once on an all-zero input and validates the result.
+///
+/// # Errors
+/// Returns the `ort` error if the session cannot be built, inference fails or the output cannot be extracted.
+///
+/// # Panics
+/// Panics if the model does not yield exactly one output of shape `[1, 448, 448, 3]`.
+pub fn upsample_inner() -> ort::Result<()> {
+    // report a model that fails to load through the `Result`, like every other step, instead of panicking
+    let session = Session::builder()?.commit_from_memory_directly(MODEL_BYTES)?;
+
+    // NOTE(review): `f32` elements assumed for this model's input/output - confirm
+    let array = Array4::<f32>::zeros((1, 224, 224, 3));
+
+    let outputs = session.run(ort::inputs![array]?)?;
+
+    assert_eq!(outputs.len(), 1);
+    let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;
+
+    assert_eq!(output.shape(), [1, 448, 448, 3]);
+
+    Ok(())
+}
+
+/// Formats its arguments like `format!` and writes the result to the JavaScript console.
+macro_rules! console_log {
+    ($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into()))
+}
+
+/// JavaScript entry point: runs [`upsample_inner`] and logs any error to the console.
+#[wasm_bindgen]
+pub fn upsample() {
+    if let Err(e) = upsample_inner() {
+        console_log!("Error occurred while performing inference: {e:?}");
+    }
+}
+
+#[cfg(test)]
+#[wasm_bindgen_test::wasm_bindgen_test]
+fn run_test() {
+    use tracing::Level;
+    use tracing_subscriber::fmt;
+    use tracing_subscriber_wasm::MakeConsoleWriter;
+
+    // installed first so a panic anywhere below is printed readably
+    std::panic::set_hook(Box::new(console_error_panic_hook::hook));
+
+    #[cfg(target_arch = "wasm32")]
+    ort::wasm::initialize();
+
+    fmt()
+        .with_ansi(false)
+        .with_max_level(Level::DEBUG)
+        .with_writer(MakeConsoleWriter::default().map_trace_level_to(Level::DEBUG))
+        .without_time()
+        .init();
+
+    // `upsample()` only logs failures, which would let a broken build pass; fail the test on any error instead
+    upsample_inner().expect("inference failed");
+}
diff --git a/examples/webassembly/src/upsample.ort b/examples/webassembly/src/upsample.ort
new file mode 100644
index 00000000..b3e43d00
Binary files /dev/null and b/examples/webassembly/src/upsample.ort differ
diff --git a/ort-sys/build.rs b/ort-sys/build.rs
index 905b492a..48cf8e6d 100644
--- a/ort-sys/build.rs
+++ b/ort-sys/build.rs
@@ -297,6 +297,10 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
                 "https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-msort_static-v1.17.1-aarch64-unknown-linux-gnu.tgz",
                 "73A569FF807D655FD6258816FBC9660667370AEB4A47C6754746BCBF07C280F9"
             ),
+            "wasm32-wasi" | "wasm32-unknown-unknown" => (
+                "https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz",
+                "41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089"
+            ),
             "wasm32-unknown-emscripten" => (
                 "https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-msort_static-v1.17.1-wasm32-unknown-emscripten.tgz",
                 "58EAD204FE53A488489287FFD97113E89A2CCA91876D3186CDBCA10A4F5A3287"
diff --git a/src/lib.rs b/src/lib.rs
index 69d9b95c..03b0b3f1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -24,11 +24,14 @@ pub(crate) mod operator;
 pub(crate) mod
session; pub(crate) mod tensor; pub(crate) mod value; +#[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))] +#[cfg(target_arch = "wasm32")] +pub mod wasm; #[cfg(feature = "load-dynamic")] use std::sync::Arc; use std::{ - ffi::{self, CStr}, + ffi::CStr, os::raw::c_char, ptr::{self, NonNull}, sync::{ @@ -154,7 +157,7 @@ pub fn api() -> NonNull { let base: *const ort_sys::OrtApiBase = base_getter(); assert_ne!(base, ptr::null()); - let get_version_string: extern_system_fn! { unsafe fn () -> *const ffi::c_char } = + let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } = (*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`"); let version_string = get_version_string(); let version_string = CStr::from_ptr(version_string).to_string_lossy(); @@ -251,11 +254,13 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result { #[cfg(test)] mod test { + use std::ffi::CString; + use super::*; #[test] fn test_char_p_to_string() { - let s = ffi::CString::new("foo").unwrap_or_else(|_| unreachable!()); + let s = CString::new("foo").unwrap_or_else(|_| unreachable!()); let ptr = s.as_c_str().as_ptr(); assert_eq!("foo", char_p_to_string(ptr).expect("failed to convert string")); } diff --git a/src/session/builder.rs b/src/session/builder.rs index 6188deeb..60f716e0 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -6,13 +6,13 @@ use std::os::windows::ffi::OsStrExt; use std::path::PathBuf; use std::{ any::Any, - ffi::CString, marker::PhantomData, - path::Path, ptr::{self, NonNull}, rc::Rc, sync::{atomic::Ordering, Arc} }; +#[cfg(not(target_arch = "wasm32"))] +use std::{ffi::CString, path::Path}; use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}; #[cfg(feature = "fetch-models")] @@ -166,6 +166,8 @@ impl SessionBuilder { /// newly optimized model to the given path (for 'offline' graph optimization). /// /// Note that the file will only be created after the model is committed. 
+ #[cfg(not(target_arch = "wasm32"))] + #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))] pub fn with_optimized_model_path>(self, path: S) -> Result { #[cfg(windows)] let path = path.as_ref().encode_utf16().chain([0]).collect::>(); @@ -177,6 +179,8 @@ impl SessionBuilder { /// Enables profiling. Profile information will be writen to `profiling_file` after profiling completes. /// See [`Session::end_profiling`]. + #[cfg(not(target_arch = "wasm32"))] + #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))] pub fn with_profiling>(self, profiling_file: S) -> Result { #[cfg(windows)] let profiling_file = profiling_file.as_ref().encode_utf16().chain([0]).collect::>(); @@ -205,8 +209,8 @@ impl SessionBuilder { } /// Registers a custom operator library at the given library path. - #[cfg(feature = "operator-libraries")] - #[cfg_attr(docsrs, doc(cfg(feature = "operator-libraries")))] + #[cfg(all(feature = "operator-libraries", not(target_arch = "wasm32")))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "operator-libraries", not(target_arch = "wasm32")))))] pub fn with_operator_library(mut self, lib_path: impl AsRef) -> Result { let path_cstr = CString::new(lib_path.as_ref())?; @@ -232,6 +236,8 @@ impl SessionBuilder { } /// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators. + #[cfg(not(target_arch = "wasm32"))] + #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))] pub fn with_extensions(self) -> Result { let status = ortsys![unsafe EnableOrtCustomOps(self.session_options_ptr.as_ptr())]; status_to_result(status).map_err(Error::CreateSessionOptions)?; @@ -246,8 +252,8 @@ impl SessionBuilder { } /// Downloads a pre-trained ONNX model from the given URL and builds the session. 
- #[cfg(feature = "fetch-models")] - #[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] + #[cfg(all(feature = "fetch-models", not(target_arch = "wasm32")))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", not(target_arch = "wasm32")))))] pub fn commit_from_url(self, model_url: impl AsRef) -> Result { let mut download_dir = ort_sys::internal::dirs::cache_dir() .expect("could not determine cache directory") @@ -294,6 +300,8 @@ impl SessionBuilder { } /// Loads an ONNX model from a file and builds the session. + #[cfg(not(target_arch = "wasm32"))] + #[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))] pub fn commit_from_file

(mut self, model_filepath_ref: P) -> Result where P: AsRef
diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs
index a94dcfab..4229537e 100644
--- a/src/value/impl_sequence.rs
+++ b/src/value/impl_sequence.rs
@@ -110,7 +110,7 @@ impl Value`] to a reference to a type-erased [`DynTensor`].
+    /// Converts from a strongly-typed [`Sequence`] to a reference to a type-erased [`DynSequence`].
     #[inline]
     pub fn upcast_ref(&self) -> DynSequenceRef {
         DynSequenceRef::new(unsafe {
@@ -121,7 +121,7 @@ impl Value`] to a mutable reference to a type-erased [`DynTensor`].
+    /// Converts from a strongly-typed [`Sequence`] to a mutable reference to a type-erased [`DynSequence`].
     #[inline]
     pub fn upcast_mut(&mut self) -> DynSequenceRefMut {
         DynSequenceRefMut::new(unsafe {
diff --git a/src/wasm.rs b/src/wasm.rs
new file mode 100644
index 00000000..fe8bdf6d
--- /dev/null
+++ b/src/wasm.rs
@@ -0,0 +1,315 @@
+//! Utilities for using `ort` in WebAssembly.
+//!
+//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs:
+//! ```
+//! # use ort::Session;
+//! # static MODEL_BYTES: &[u8] = include_bytes!("../tests/data/upsample.ort");
+//! # fn main() -> ort::Result<()> {
+//! #[cfg(target_arch = "wasm32")]
+//! ort::wasm::initialize();
+//!
+//! let session = Session::builder()?.commit_from_memory_directly(MODEL_BYTES)?;
+//! # Ok(())
+//! # }
+//! ```
+
+use std::{
+    alloc::{self, Layout},
+    arch::wasm32,
+    ptr, slice,
+    sync::Once
+};
+
+/// C-ABI stubs for localized time formatting/conversion symbols.
+///
+/// None of them is implemented: reaching any of these panics via `unimplemented!()`.
+mod fmt_shims {
+    // localized time string formatting functions
+    // TODO: remove any remaining codepaths to these
+
+    #[no_mangle]
+    pub unsafe extern "C" fn strftime_l(_s: *mut u8, _l: usize, _m: *const u8, _t: *const (), _lt: *const ()) -> usize {
+        unimplemented!()
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn _tzset_js(_timezone: *mut u32, _daylight: *const i32, _name: *const u8, _dst_name: *mut u8) {
+        unimplemented!()
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn _mktime_js(_tm: *mut ()) -> ! {
+        unimplemented!()
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn _localtime_js(_time_t: i64, _tm: *mut ()) -> ! {
+        unimplemented!()
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn _gmtime_js(_time_t: i64, _tm: *mut ()) -> ! {
+        unimplemented!()
+    }
+}
+
+/// C heap entry points (`malloc`, `free`, `realloc`, ...) implemented on top of Rust's global allocator.
+pub(crate) mod libc_shims {
+    use super::*;
+
+    // Rust, unlike C, requires us to know the exact layout of an allocation in order to deallocate it, so we need to
+    // store this data at the beginning of the allocation for us to be able to pick up on deallocation:
+    //
+    // ┌---- actual allocated pointer
+    // ▼
+    // +-------------+-------+------+----------------+
+    // | ...padding  | align | size | data...        |
+    // | -align..-8  | -8    | -4   | 0..size        |
+    // +-------------+-------+------+----------------+
+    //                              ▲
+    // pointer returned to C     ---┘
+    //
+    // This does unfortunately mean we waste a little extra memory (note that most allocators *also* store the layout
+    // information in a similar manner, but we can't access it).
+
+    const _: () = assert!(std::mem::size_of::<usize>() == 4, "32-bit pointer width (wasm32) required");
+
+    /// Alignment of blocks returned by the entry points that take no explicit alignment.
+    const DEFAULT_ALIGNMENT: usize = 32;
+
+    // Failure codes for `posix_memalign`.
+    // NOTE(review): assumes the WASI errno numbering (the one `__WASI_ENOTSUP = 58` in `wasi_shims` comes from) -
+    // confirm against the libc headers the static library was compiled with.
+    const EINVAL: i32 = 28;
+    const ENOMEM: i32 = 48;
+
+    /// Layout of the real allocation backing `size` user bytes: an extra `align`-sized prefix holds the header.
+    ///
+    /// Returns `None` if the total size overflows or `align` is not a valid alignment.
+    fn backing_layout(size: usize, align: usize) -> Option<Layout> {
+        Layout::from_size_align(size.checked_add(align)?, align).ok()
+    }
+
+    /// Stores `size` and `align` in the 8 bytes directly in front of `user`.
+    ///
+    /// # Safety
+    /// `user` must point `align` (>= 8) bytes into a live allocation that is aligned to `align`.
+    unsafe fn write_header(user: *mut u8, size: usize, align: usize) {
+        user.sub(4).cast::<usize>().write(size);
+        user.sub(8).cast::<usize>().write(align);
+    }
+
+    /// Reads back the `(size, align)` pair stored by `write_header`.
+    ///
+    /// # Safety
+    /// `user` must be a non-null pointer previously returned by one of the allocation functions in this module.
+    unsafe fn read_header(user: *mut u8) -> (usize, usize) {
+        (user.sub(4).cast::<usize>().read(), user.sub(8).cast::<usize>().read())
+    }
+
+    /// Allocates `size` bytes aligned to at least `align`, zero-filled when `ZERO` is set.
+    ///
+    /// Returns null if the request cannot be expressed as a valid layout or the allocator fails.
+    unsafe fn alloc_inner<const ZERO: bool>(size: usize, align: usize) -> *mut u8 {
+        // need enough space to store the size & alignment bytes
+        let align = align.max(8);
+
+        let Some(layout) = backing_layout(size, align) else {
+            return ptr::null_mut();
+        };
+        let base = if ZERO { alloc::alloc_zeroed(layout) } else { alloc::alloc(layout) };
+        if base.is_null() {
+            // C callers expect a null return on exhaustion, not a header write through a null pointer
+            return base;
+        }
+
+        let user = base.add(align);
+        write_header(user, size, align);
+        user
+    }
+
+    /// Releases a block obtained from this module; a null `ptr` is ignored.
+    unsafe fn free_inner(ptr: *mut u8) {
+        // something likes to free(NULL) a lot, which is valid in C (because of course it is...)
+        if ptr.is_null() {
+            return;
+        }
+
+        let (size, align) = read_header(ptr);
+        // SAFETY: the header always describes a layout that was validated when the block was (re)allocated
+        let layout = Layout::from_size_align_unchecked(size + align, align);
+        alloc::dealloc(ptr.sub(align), layout);
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn malloc(size: usize) -> *mut u8 {
+        alloc_inner::<false>(size, DEFAULT_ALIGNMENT)
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn __libc_malloc(size: usize) -> *mut u8 {
+        alloc_inner::<false>(size, DEFAULT_ALIGNMENT)
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn __libc_calloc(n: usize, size: usize) -> *mut u8 {
+        // an overflowing element count must fail rather than wrap around to a too-small block
+        match n.checked_mul(size) {
+            Some(total) => alloc_inner::<true>(total, DEFAULT_ALIGNMENT),
+            None => ptr::null_mut()
+        }
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn free(ptr: *mut u8) {
+        free_inner(ptr)
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn __libc_free(ptr: *mut u8) {
+        free_inner(ptr)
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn posix_memalign(ptr: *mut *mut u8, align: usize, size: usize) -> i32 {
+        if !align.is_power_of_two() {
+            return EINVAL;
+        }
+        let block = alloc_inner::<false>(size, align);
+        if block.is_null() {
+            return ENOMEM;
+        }
+        *ptr = block;
+        0
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn realloc(ptr: *mut u8, newsize: usize) -> *mut u8 {
+        // `realloc(NULL, n)` must behave like `malloc(n)`; there is no header to read in that case
+        if ptr.is_null() {
+            return alloc_inner::<false>(newsize, DEFAULT_ALIGNMENT);
+        }
+
+        let (size, align) = read_header(ptr);
+        let old_layout = Layout::from_size_align_unchecked(size + align, align);
+        // the header prefix is part of the allocation, so the backing block needs `newsize + align` bytes
+        let Some(new_layout) = backing_layout(newsize, align) else {
+            return ptr::null_mut();
+        };
+
+        let base = alloc::realloc(ptr.sub(align), old_layout, new_layout.size());
+        if base.is_null() {
+            // on failure the old block is left untouched and stays valid
+            return base;
+        }
+
+        // record the *new* size so a later `free`/`realloc` rebuilds the right layout; `align` was carried over
+        let user = base.add(align);
+        write_header(user, newsize, align);
+        user
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn abort() -> ! {
+        std::process::abort()
+    }
+}
+
+/// Stand-ins for WASI syscalls, compiled in when the target OS is not WASI.
+#[cfg(not(target_os = "wasi"))]
+mod wasi_shims {
+    #[allow(non_camel_case_types)]
+    type __wasi_errno_t = u16;
+
+    const __WASI_ENOTSUP: __wasi_errno_t = 58;
+
+    // mock filesystem for non-WASI platforms - most of the codepaths to any FS operations should've been removed, but we
+    // return ENOTSUP just to be safe
+
+    #[no_mangle]
+    pub unsafe extern "C" fn __wasi_environ_sizes_get(argc: *mut usize, argv_buf_size: *mut usize) -> __wasi_errno_t {
+        *argc = 0;
+        *argv_buf_size = 0;
+        __WASI_ENOTSUP
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn __wasi_environ_get(_environ: *mut *mut u8, _buf: *mut u8) -> __wasi_errno_t {
+        __WASI_ENOTSUP
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn __wasi_fd_seek(_fd: u32, _offset: i64, _whence: u8, _new_offset: *mut u64) -> __wasi_errno_t {
+        __WASI_ENOTSUP
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn __wasi_fd_write(_fd: u32, _iovs: *const (), _iovs_len: usize, _nwritten: *mut usize) -> __wasi_errno_t {
+        __WASI_ENOTSUP
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn __wasi_fd_read(_fd: u32, _iovs: *const (), _iovs_len: usize, _nread: *mut usize) -> __wasi_errno_t {
+        __WASI_ENOTSUP
+    }
+    #[no_mangle]
+    pub unsafe extern "C" fn __wasi_fd_close(_fd: u32) -> __wasi_errno_t {
+        __WASI_ENOTSUP
+    }
+}
+
+/// Implementations of Emscripten runtime functions.
+mod emscripten_shims {
+    use super::*;
+
+    #[no_mangle]
+    pub unsafe extern "C" fn emscripten_memcpy_js(dst: *mut (), src: *const (), n: usize) {
+        // `n` counts bytes; copying through the `()` pointers as-is would move `n` zero-sized values, i.e. nothing
+        std::ptr::copy_nonoverlapping(src.cast::<u8>(), dst.cast::<u8>(), n)
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn emscripten_get_now() -> f64 {
+        js_sys::Date::now()
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn emscripten_get_heap_max() -> usize {
+        // current size of memory 0 in bytes (`memory_size` counts 64 KiB pages), clamped instead of wrapping to 0
+        wasm32::memory_size(0).saturating_mul(1 << 16)
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn emscripten_date_now() -> f64 {
+        js_sys::Date::now()
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn _emscripten_get_now_is_monotonic() -> i32 {
+        0
+    }
+
+    #[no_mangle]
+    pub unsafe extern "C" fn emscripten_builtin_malloc(size: usize) -> *mut u8 {
+        // go through the header-writing allocator: the result is still zeroed and 32-byte aligned, but it can now
+        // be handed to `free`/`realloc`, and a zero-sized request no longer reaches the global allocator
+        libc_shims::__libc_calloc(1, size)
+    }
+
+    #[no_mangle]
+    #[tracing::instrument]
+    pub unsafe extern "C" fn emscripten_errn(str: *const u8, len: usize) {
+        // the bytes come from foreign code, so don't assume they are valid UTF-8 (or that the pointer is non-null)
+        let bytes: &[u8] = if str.is_null() { &[] } else { slice::from_raw_parts(str, len) };
+        let c = String::from_utf8_lossy(bytes);
+        tracing::error!("Emscripten error: {c}");
+    }
+}
+
+/// Runs the module's static constructors (`__wasm_call_ctors`); repeated calls are no-ops.
+#[export_name = "_initialize"]
+pub fn initialize() {
+    // No idea what the hell this does, but the presence of an `_initialize` function prevents the linker from calling
+    // `__wasm_call_ctors` at the top of every function - including the functions `wasm-bindgen` interprets to generate
+    // JS glue code. The `__wasm_call_ctors` call was calling complex functions that the interpreter isn't equipped to
+    // handle, which was preventing wbg from outputting anything. I don't know what specific constructors this is calling,
+    // and most basic ONNX Runtime APIs *do* work without calling this, but we encourage the user to perform this
+    // initialization at program start anyways to be safe.
+    extern "C" {
+        fn __wasm_call_ctors();
+    }
+
+    // the exported `_initialize` symbol and user code may both end up here; constructors must still run only once
+    static CTORS: Once = Once::new();
+    CTORS.call_once(|| unsafe { __wasm_call_ctors() });
+}