Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wasi-nn: update upstream specification #6853

Merged
merged 1 commit into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/wasi-nn/spec
Submodule spec updated 120 files
21 changes: 9 additions & 12 deletions crates/wasi-nn/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,30 @@ mod openvino;

use self::openvino::OpenvinoBackend;
use crate::wit::types::{ExecutionTarget, Tensor};
use crate::{ExecutionContext, Graph};
use thiserror::Error;
use wiggle::GuestError;

/// Return a list of all available backend frameworks.
pub(crate) fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
pub fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
}

/// A [Backend] contains the necessary state to load [BackendGraph]s.
pub(crate) trait Backend: Send + Sync {
pub trait Backend: Send + Sync {
fn name(&self) -> &str;
fn load(
&mut self,
builders: &[&[u8]],
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError>;
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
}

/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
/// implementation for a [crate::witx::types::Graph].
pub(crate) trait BackendGraph: Send + Sync {
fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError>;
pub trait BackendGraph: Send + Sync {
fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError>;
}

/// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a [crate::witx::types::GraphExecutionContext].
pub(crate) trait BackendExecutionContext: Send + Sync {
pub trait BackendExecutionContext: Send + Sync {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>;
fn compute(&mut self) -> Result<(), BackendError>;
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
Expand All @@ -52,7 +49,7 @@ pub enum BackendError {
NotEnoughMemory(usize),
}

#[derive(Hash, PartialEq, Eq, Clone, Copy)]
pub(crate) enum BackendKind {
#[derive(Hash, PartialEq, Debug, Eq, Clone, Copy)]
pub enum BackendKind {
OpenVINO,
}
22 changes: 10 additions & 12 deletions crates/wasi-nn/src/backend/openvino.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
use std::sync::Arc;

Expand All @@ -15,11 +16,7 @@ impl Backend for OpenvinoBackend {
"openvino"
}

fn load(
&mut self,
builders: &[&[u8]],
target: ExecutionTarget,
) -> Result<Box<dyn BackendGraph>, BackendError> {
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
if builders.len() != 2 {
return Err(BackendError::InvalidNumberOfBuilders(2, builders.len()).into());
}
Expand Down Expand Up @@ -54,8 +51,9 @@ impl Backend for OpenvinoBackend {

let exec_network =
core.load_network(&cnn_network, map_execution_target_to_string(target))?;

Ok(Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network)))
let box_: Box<dyn BackendGraph> =
Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network));
Ok(box_.into())
}
}

Expand All @@ -65,12 +63,11 @@ unsafe impl Send for OpenvinoGraph {}
unsafe impl Sync for OpenvinoGraph {}

impl BackendGraph for OpenvinoGraph {
fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError> {
fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError> {
let infer_request = self.1.create_infer_request()?;
Ok(Box::new(OpenvinoExecutionContext(
self.0.clone(),
infer_request,
)))
let box_: Box<dyn BackendExecutionContext> =
Box::new(OpenvinoExecutionContext(self.0.clone(), infer_request));
Ok(box_.into())
}
}

Expand Down Expand Up @@ -145,5 +142,6 @@ fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision
TensorType::Fp32 => Precision::FP32,
TensorType::U8 => Precision::U8,
TensorType::I32 => Precision::I32,
TensorType::Bf16 => todo!("not yet supported in `openvino` bindings"),
}
}
34 changes: 18 additions & 16 deletions crates/wasi-nn/src/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
//! Implements the host state for the `wasi-nn` API: [WasiNnCtx].

use crate::backend::{
self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind,
};
use crate::backend::{self, Backend, BackendError, BackendKind};
use crate::wit::types::GraphEncoding;
use std::collections::HashMap;
use std::hash::Hash;
use crate::{ExecutionContext, Graph};
use std::{collections::HashMap, hash::Hash};
use thiserror::Error;
use wiggle::GuestError;

type Backends = HashMap<BackendKind, Box<dyn Backend>>;
type GraphId = u32;
type GraphExecutionContextId = u32;

/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<BackendKind, Box<dyn Backend>>,
pub(crate) graphs: Table<GraphId, Box<dyn BackendGraph>>,
pub(crate) executions: Table<GraphExecutionContextId, Box<dyn BackendExecutionContext>>,
pub(crate) backends: Backends,
pub(crate) graphs: Table<GraphId, Graph>,
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}

impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new() -> WasiNnResult<Self> {
let mut backends = HashMap::new();
for (kind, backend) in backend::list() {
backends.insert(kind, backend);
}
Ok(Self {
pub fn new(backends: Backends) -> Self {
Self {
backends,
graphs: Table::default(),
executions: Table::default(),
})
}
}
}
impl Default for WasiNnCtx {
fn default() -> Self {
WasiNnCtx::new(backend::list().into_iter().collect())
}
}

Expand Down Expand Up @@ -59,6 +59,8 @@ pub enum UsageError {
InvalidExecutionContextHandle,
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(u32),
#[error("No graph found with name: {0}")]
NotFound(String),
}

pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
Expand Down Expand Up @@ -105,6 +107,6 @@ mod test {

#[test]
fn instantiate() {
WasiNnCtx::new().unwrap();
WasiNnCtx::default();
}
}
38 changes: 38 additions & 0 deletions crates/wasi-nn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,41 @@ mod ctx;
pub use ctx::WasiNnCtx;
pub mod wit;
pub mod witx;

/// A backend-defined graph (i.e., ML model).
pub struct Graph(Box<dyn backend::BackendGraph>);
impl From<Box<dyn backend::BackendGraph>> for Graph {
fn from(value: Box<dyn backend::BackendGraph>) -> Self {
Self(value)
}
}
impl std::ops::Deref for Graph {
type Target = dyn backend::BackendGraph;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl std::ops::DerefMut for Graph {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut()
}
}

/// A backend-defined execution context.
pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
Self(value)
}
}
impl std::ops::Deref for ExecutionContext {
type Target = dyn backend::BackendExecutionContext;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl std::ops::DerefMut for ExecutionContext {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut()
}
}
57 changes: 36 additions & 21 deletions crates/wasi-nn/src/wit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,29 @@

use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx};

pub use gen::types;
pub use gen_::Ml as ML;

/// Generate the traits and types from the `wasi-nn` WIT specification.
mod gen_ {
wasmtime::component::bindgen!("ml");
wasmtime::component::bindgen!("ml" in "spec/wit/wasi-nn.wit");
}
use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need.

impl gen::inference::Host for WasiNnCtx {
// Export the `types` used in this crate as well as `ML::add_to_linker`.
pub mod types {
use super::gen;
pub use gen::graph::{ExecutionTarget, Graph, GraphEncoding};
pub use gen::inference::GraphExecutionContext;
pub use gen::tensor::{Tensor, TensorType};
}
pub use gen_::Ml as ML;

impl gen::graph::Host for WasiNnCtx {
/// Load an opaque sequence of bytes to use for inference.
fn load(
&mut self,
builders: gen::types::GraphBuilderArray,
encoding: gen::types::GraphEncoding,
target: gen::types::ExecutionTarget,
) -> wasmtime::Result<Result<gen::types::Graph, gen::types::Error>> {
builders: Vec<gen::graph::GraphBuilder>,
encoding: gen::graph::GraphEncoding,
target: gen::graph::ExecutionTarget,
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
let backend_kind: BackendKind = encoding.try_into()?;
let graph = if let Some(backend) = self.backends.get_mut(&backend_kind) {
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
Expand All @@ -45,13 +51,22 @@ impl gen::inference::Host for WasiNnCtx {
Ok(Ok(graph_id))
}

fn load_by_name(
&mut self,
_name: String,
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
todo!()
}
}

impl gen::inference::Host for WasiNnCtx {
/// Create an execution instance of a loaded graph.
///
/// TODO: remove completely?
fn init_execution_context(
&mut self,
graph_id: gen::types::Graph,
) -> wasmtime::Result<Result<gen::types::GraphExecutionContext, gen::types::Error>> {
graph_id: gen::graph::Graph,
) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> {
let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
graph.init_execution_context()?
} else {
Expand All @@ -65,10 +80,10 @@ impl gen::inference::Host for WasiNnCtx {
/// Define the inputs to use for inference.
fn set_input(
&mut self,
exec_context_id: gen::types::GraphExecutionContext,
exec_context_id: gen::inference::GraphExecutionContext,
index: u32,
tensor: gen::types::Tensor,
) -> wasmtime::Result<Result<(), gen::types::Error>> {
tensor: gen::tensor::Tensor,
) -> wasmtime::Result<Result<(), gen::errors::Error>> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
exec_context.set_input(index, &tensor)?;
Ok(Ok(()))
Expand All @@ -82,8 +97,8 @@ impl gen::inference::Host for WasiNnCtx {
/// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error>
fn compute(
&mut self,
exec_context_id: gen::types::GraphExecutionContext,
) -> wasmtime::Result<Result<(), gen::types::Error>> {
exec_context_id: gen::inference::GraphExecutionContext,
) -> wasmtime::Result<Result<(), gen::errors::Error>> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
exec_context.compute()?;
Ok(Ok(()))
Expand All @@ -95,9 +110,9 @@ impl gen::inference::Host for WasiNnCtx {
/// Extract the outputs after inference.
fn get_output(
&mut self,
exec_context_id: gen::types::GraphExecutionContext,
exec_context_id: gen::inference::GraphExecutionContext,
index: u32,
) -> wasmtime::Result<Result<gen::types::TensorData, gen::types::Error>> {
) -> wasmtime::Result<Result<gen::tensor::TensorData, gen::errors::Error>> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
// Read the output bytes. TODO: this involves a hard-coded upper
// limit on the tensor size that is necessary because there is no
Expand All @@ -113,11 +128,11 @@ impl gen::inference::Host for WasiNnCtx {
}
}

impl TryFrom<gen::types::GraphEncoding> for crate::backend::BackendKind {
impl TryFrom<gen::graph::GraphEncoding> for crate::backend::BackendKind {
type Error = UsageError;
fn try_from(value: gen::types::GraphEncoding) -> Result<Self, Self::Error> {
fn try_from(value: gen::graph::GraphEncoding) -> Result<Self, Self::Error> {
match value {
gen::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
gen::graph::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO),
_ => Err(UsageError::InvalidEncoding(value.into())),
}
}
Expand Down
Loading