Skip to content

Commit

Permalink
feat: support for wasm32-unknown-unknown target (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 authored Apr 25, 2024
1 parent cedeb55 commit b12c43c
Show file tree
Hide file tree
Showing 11 changed files with 360 additions and 12 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ jobs:
cargo test -p ort --verbose --features fetch-models -- --test-threads 1
# Test examples that use in-tree graphs (do NOT run any of the examples that download ~700 MB graphs from pyke parcel...)
cargo run --example custom-ops
test-wasm:
name: Test WebAssembly
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install nightly Rust toolchain
uses: dtolnay/rust-toolchain@nightly
- name: Install wasm-pack
run: cargo install wasm-pack
- name: Run tests
working-directory: examples/webassembly
run: |
wasm-pack test --node
# Disable cross-compile until cross updates aarch64-unknown-linux-gnu to Ubuntu 22.04
# ref https://github.com/cross-rs/cross/pull/973
#cross-compile:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ WixTools/
# ONNX Runtime downloaded models
**/*.onnx
**/*.ort
!examples/webassembly/**/*.ort
!tests/data/*.onnx
!tests/data/*.ort

Expand Down
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ members = [
'examples/gpt2',
'examples/model-info',
'examples/yolov8',
'examples/modnet'
'examples/modnet',
'examples/webassembly'
]
default-members = [
'.',
Expand Down Expand Up @@ -45,6 +46,7 @@ codegen-units = 1

[package.metadata.docs.rs]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
rustdoc-args = [ "--cfg", "docsrs" ]

[features]
Expand Down Expand Up @@ -91,6 +93,10 @@ libc = { version = "0.2", optional = true }
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", optional = true, features = [ "std", "libloaderapi" ] }

[target.'cfg(target_family = "wasm")'.dependencies]
js-sys = "0.3"
web-sys = "0.3"

[dev-dependencies]
anyhow = "1.0"
ureq = "2.1"
Expand All @@ -100,6 +106,7 @@ tracing-subscriber = { version = "0.3", default-features = false, features = [ "
glassbench = "0.4"
tokio = { version = "1.36", features = [ "test-util" ] }
tokio-test = "0.4.3"
wasm-bindgen-test = "0.3"

[[bench]]
name = "squeezenet"
Expand Down
25 changes: 25 additions & 0 deletions examples/webassembly/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[package]
publish = false
name = "example-webassembly"
version = "0.0.0"
edition = "2021"

[lib]
name = "ortwasm"
crate-type = ["cdylib"]

[dependencies]
ort = { path = "../../" }
ndarray = "0.15"
wasm-bindgen = "0.2.92"
web-sys = "0.3"
tracing = "0.1"
tracing-subscriber = "0.3"
tracing-subscriber-wasm = "0.1"

[dev-dependencies]
wasm-bindgen-test = "0.3"
console_error_panic_hook = "0.1"

[features]
load-dynamic = [ "ort/load-dynamic" ]
55 changes: 55 additions & 0 deletions examples/webassembly/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use ndarray::{Array4, ArrayViewD};
use ort::Session;
use wasm_bindgen::prelude::*;

static MODEL_BYTES: &[u8] = include_bytes!("upsample.ort");

pub fn upsample_inner() -> ort::Result<()> {
let session = Session::builder()?
.commit_from_memory_directly(MODEL_BYTES)
.expect("Could not read model from memory");

let array = Array4::<f32>::zeros((1, 224, 224, 3));

let outputs = session.run(ort::inputs![array]?)?;

assert_eq!(outputs.len(), 1);
let output: ArrayViewD<f32> = outputs[0].try_extract_tensor()?;

assert_eq!(output.shape(), [1, 448, 448, 3]);

Ok(())
}

macro_rules! console_log {
($($t:tt)*) => (web_sys::console::log_1(&format_args!($($t)*).to_string().into()))
}

#[wasm_bindgen]
pub fn upsample() {
if let Err(e) = upsample_inner() {
console_log!("Error occurred while performing inference: {e:?}");
}
}

#[cfg(test)]
#[wasm_bindgen_test::wasm_bindgen_test]
fn run_test() {
use tracing::Level;
use tracing_subscriber::fmt;
use tracing_subscriber_wasm::MakeConsoleWriter;

#[cfg(target_arch = "wasm32")]
ort::wasm::initialize();

fmt()
.with_ansi(false)
.with_max_level(Level::DEBUG)
.with_writer(MakeConsoleWriter::default().map_trace_level_to(Level::DEBUG))
.without_time()
.init();

std::panic::set_hook(Box::new(console_error_panic_hook::hook));

upsample();
}
Binary file added examples/webassembly/src/upsample.ort
Binary file not shown.
4 changes: 4 additions & 0 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-msort_static-v1.17.1-aarch64-unknown-linux-gnu.tgz",
"73A569FF807D655FD6258816FBC9660667370AEB4A47C6754746BCBF07C280F9"
),
"wasm32-wasi" | "wasm32-unknown-unknown" => (
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz",
"41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089"
),
"wasm32-unknown-emscripten" => (
"https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-msort_static-v1.17.1-wasm32-unknown-emscripten.tgz",
"58EAD204FE53A488489287FFD97113E89A2CCA91876D3186CDBCA10A4F5A3287"
Expand Down
11 changes: 8 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ pub(crate) mod operator;
pub(crate) mod session;
pub(crate) mod tensor;
pub(crate) mod value;
#[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))]
#[cfg(target_arch = "wasm32")]
pub mod wasm;

#[cfg(feature = "load-dynamic")]
use std::sync::Arc;
use std::{
ffi::{self, CStr},
ffi::CStr,
os::raw::c_char,
ptr::{self, NonNull},
sync::{
Expand Down Expand Up @@ -154,7 +157,7 @@ pub fn api() -> NonNull<ort_sys::OrtApi> {
let base: *const ort_sys::OrtApiBase = base_getter();
assert_ne!(base, ptr::null());

let get_version_string: extern_system_fn! { unsafe fn () -> *const ffi::c_char } =
let get_version_string: extern_system_fn! { unsafe fn () -> *const c_char } =
(*base).GetVersionString.expect("`GetVersionString` must be present in `OrtApiBase`");
let version_string = get_version_string();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
Expand Down Expand Up @@ -251,11 +254,13 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result<String> {

#[cfg(test)]
mod test {
use std::ffi::CString;

use super::*;

#[test]
fn test_char_p_to_string() {
let s = ffi::CString::new("foo").unwrap_or_else(|_| unreachable!());
let s = CString::new("foo").unwrap_or_else(|_| unreachable!());
let ptr = s.as_c_str().as_ptr();
assert_eq!("foo", char_p_to_string(ptr).expect("failed to convert string"));
}
Expand Down
20 changes: 14 additions & 6 deletions src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ use std::os::windows::ffi::OsStrExt;
use std::path::PathBuf;
use std::{
any::Any,
ffi::CString,
marker::PhantomData,
path::Path,
ptr::{self, NonNull},
rc::Rc,
sync::{atomic::Ordering, Arc}
};
#[cfg(not(target_arch = "wasm32"))]
use std::{ffi::CString, path::Path};

use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner};
#[cfg(feature = "fetch-models")]
Expand Down Expand Up @@ -166,6 +166,8 @@ impl SessionBuilder {
/// newly optimized model to the given path (for 'offline' graph optimization).
///
/// Note that the file will only be created after the model is committed.
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub fn with_optimized_model_path<S: AsRef<str>>(self, path: S) -> Result<Self> {
#[cfg(windows)]
let path = path.as_ref().encode_utf16().chain([0]).collect::<Vec<_>>();
Expand All @@ -177,6 +179,8 @@ impl SessionBuilder {

/// Enables profiling. Profile information will be writen to `profiling_file` after profiling completes.
/// See [`Session::end_profiling`].
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub fn with_profiling<S: AsRef<str>>(self, profiling_file: S) -> Result<Self> {
#[cfg(windows)]
let profiling_file = profiling_file.as_ref().encode_utf16().chain([0]).collect::<Vec<_>>();
Expand Down Expand Up @@ -205,8 +209,8 @@ impl SessionBuilder {
}

/// Registers a custom operator library at the given library path.
#[cfg(feature = "operator-libraries")]
#[cfg_attr(docsrs, doc(cfg(feature = "operator-libraries")))]
#[cfg(all(feature = "operator-libraries", not(target_arch = "wasm32")))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "operator-libraries", not(target_arch = "wasm32")))))]
pub fn with_operator_library(mut self, lib_path: impl AsRef<str>) -> Result<Self> {
let path_cstr = CString::new(lib_path.as_ref())?;

Expand All @@ -232,6 +236,8 @@ impl SessionBuilder {
}

/// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators.
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub fn with_extensions(self) -> Result<Self> {
let status = ortsys![unsafe EnableOrtCustomOps(self.session_options_ptr.as_ptr())];
status_to_result(status).map_err(Error::CreateSessionOptions)?;
Expand All @@ -246,8 +252,8 @@ impl SessionBuilder {
}

/// Downloads a pre-trained ONNX model from the given URL and builds the session.
#[cfg(feature = "fetch-models")]
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
#[cfg(all(feature = "fetch-models", not(target_arch = "wasm32")))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "fetch-models", not(target_arch = "wasm32")))))]
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> {
let mut download_dir = ort_sys::internal::dirs::cache_dir()
.expect("could not determine cache directory")
Expand Down Expand Up @@ -294,6 +300,8 @@ impl SessionBuilder {
}

/// Loads an ONNX model from a file and builds the session.
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session>
where
P: AsRef<Path>
Expand Down
4 changes: 2 additions & 2 deletions src/value/impl_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
unsafe { std::mem::transmute(self) }
}

/// Converts from a strongly-typed [`Sequence<T>`] to a reference to a type-erased [`DynTensor`].
/// Converts from a strongly-typed [`Sequence<T>`] to a reference to a type-erased [`DynSequence`].
#[inline]
pub fn upcast_ref(&self) -> DynSequenceRef {
DynSequenceRef::new(unsafe {
Expand All @@ -121,7 +121,7 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
})
}

/// Converts from a strongly-typed [`Sequence<T>`] to a mutable reference to a type-erased [`DynTensor`].
/// Converts from a strongly-typed [`Sequence<T>`] to a mutable reference to a type-erased [`DynSequence`].
#[inline]
pub fn upcast_mut(&mut self) -> DynSequenceRefMut {
DynSequenceRefMut::new(unsafe {
Expand Down
Loading

0 comments on commit b12c43c

Please sign in to comment.