diff --git a/docs/pages/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx index f2e7b9e0..923590b2 100644 --- a/docs/pages/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -107,7 +107,22 @@ fn main() -> anyhow::Result<()> { ## Fallback behavior `ort` will silently fail and fall back to executing on the CPU if all execution providers fail to register. In many cases, though, you'll want to show the user an error message when an EP fails to register, or outright abort the process. -To receive these registration errors, instead use `ExecutionProvider::register` to register an execution provider: +You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`: +```rust +use ort::{CUDAExecutionProvider, Session}; + +fn main() -> anyhow::Result<()> { + let session = Session::builder()? + .with_execution_providers([ + CUDAExecutionProvider::default().build().error_on_failure() + ])? 
+ .commit_from_file("model.onnx")?; + + Ok(()) +} +``` + +If you require more complex error handling, you can also manually register execution providers via the `ExecutionProvider::register` method: ```rust use ort::{CUDAExecutionProvider, ExecutionProvider, Session}; diff --git a/examples/cudarc/src/main.rs b/examples/cudarc/src/main.rs index 1ffc01f0..20013a9f 100644 --- a/examples/cudarc/src/main.rs +++ b/examples/cudarc/src/main.rs @@ -11,7 +11,7 @@ fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); ort::init() - .with_execution_providers([CUDAExecutionProvider::default().build()]) + .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()]) .commit()?; let model = diff --git a/src/execution_providers/acl.rs b/src/execution_providers/acl.rs index a8e3bdb3..1f15ac70 100644 --- a/src/execution_providers/acl.rs +++ b/src/execution_providers/acl.rs @@ -26,7 +26,7 @@ impl ACLExecutionProvider { impl From<ACLExecutionProvider> for ExecutionProviderDispatch { fn from(value: ACLExecutionProvider) -> Self { - ExecutionProviderDispatch::ACL(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/armnn.rs b/src/execution_providers/armnn.rs index 53a38795..86332f01 100644 --- a/src/execution_providers/armnn.rs +++ b/src/execution_providers/armnn.rs @@ -26,7 +26,7 @@ impl ArmNNExecutionProvider { impl From<ArmNNExecutionProvider> for ExecutionProviderDispatch { fn from(value: ArmNNExecutionProvider) -> Self { - ExecutionProviderDispatch::ArmNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cann.rs b/src/execution_providers/cann.rs index c43e8e06..f37a2f1b 100644 --- a/src/execution_providers/cann.rs +++ b/src/execution_providers/cann.rs @@ -109,7 +109,7 @@ impl CANNExecutionProvider { impl From<CANNExecutionProvider> for ExecutionProviderDispatch { fn from(value: CANNExecutionProvider) -> Self { - ExecutionProviderDispatch::CANN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/coreml.rs 
b/src/execution_providers/coreml.rs index 94971e8a..256de1e5 100644 --- a/src/execution_providers/coreml.rs +++ b/src/execution_providers/coreml.rs @@ -46,7 +46,7 @@ impl CoreMLExecutionProvider { impl From<CoreMLExecutionProvider> for ExecutionProviderDispatch { fn from(value: CoreMLExecutionProvider) -> Self { - ExecutionProviderDispatch::CoreML(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cpu.rs b/src/execution_providers/cpu.rs index 2f98095d..eb4be919 100644 --- a/src/execution_providers/cpu.rs +++ b/src/execution_providers/cpu.rs @@ -21,7 +21,7 @@ impl CPUExecutionProvider { impl From<CPUExecutionProvider> for ExecutionProviderDispatch { fn from(value: CPUExecutionProvider) -> Self { - ExecutionProviderDispatch::CPU(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cuda.rs b/src/execution_providers/cuda.rs index 77200e3c..17fbe825 100644 --- a/src/execution_providers/cuda.rs +++ b/src/execution_providers/cuda.rs @@ -161,7 +161,7 @@ impl CUDAExecutionProvider { impl From<CUDAExecutionProvider> for ExecutionProviderDispatch { fn from(value: CUDAExecutionProvider) -> Self { - ExecutionProviderDispatch::CUDA(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/directml.rs b/src/execution_providers/directml.rs index 71802553..38556f11 100644 --- a/src/execution_providers/directml.rs +++ b/src/execution_providers/directml.rs @@ -26,7 +26,7 @@ impl DirectMLExecutionProvider { impl From<DirectMLExecutionProvider> for ExecutionProviderDispatch { fn from(value: DirectMLExecutionProvider) -> Self { - ExecutionProviderDispatch::DirectML(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index 24ec6acf..8c237041 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, os::raw::c_char}; +use std::{fmt::Debug, os::raw::c_char, sync::Arc}; use crate::{char_p_to_string, ortsys, Error, Result, SessionBuilder}; @@ -60,16 
+60,17 @@ pub trait ExecutionProvider { true } - /// Returns `Ok(true)` if ONNX Runtime was compiled with support for this execution provider, and `Ok(false)` + /// Returns `Ok(true)` if ONNX Runtime was *compiled with support* for this execution provider, and `Ok(false)` /// otherwise. /// /// An `Err` may be returned if a serious internal error occurs, in which case your application should probably /// just abort. /// - /// Note that this does not always mean the execution provider is *usable* for a specific model. A model may use - /// operators not supported by an execution provider, or the EP may encounter an error while attempting to load a - /// dynamic library during registration. In most cases (i.e. showing the user an error message if CUDA could not be - /// enabled), you'll instead want to detect and handle errors from [`ExecutionProvider::register`]. + /// **Note that this does not always mean the execution provider is *usable* for a specific session.** A model may + /// use operators not supported by an execution provider, or the EP may encounter an error while attempting to load + /// dependencies during session creation. In most cases (e.g. showing the user an error message if CUDA could not be + /// enabled), you'll instead want to manually register this EP via [`ExecutionProvider::register`] and detect + /// and handle any errors returned by that function. fn is_available(&self) -> Result<bool> { let mut providers: *mut *mut c_char = std::ptr::null_mut(); let mut num_providers = 0; @@ -110,56 +111,50 @@ pub enum ArenaExtendStrategy { SameAsRequested } -/// Execution provider container. See [the ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/) for more -/// info on execution providers. Execution providers are actually registered via the functions [`crate::SessionBuilder`] -/// (per-session) or [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder) (default for all sessions in an -/// environment). 
-#[derive(Debug, Clone)] +/// Dynamic execution provider container, used to provide a list of multiple types of execution providers when +/// configuring execution providers for a [`SessionBuilder`](crate::SessionBuilder) or +/// [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder). +/// +/// See [`ExecutionProvider`] for more info on execution providers. +#[derive(Clone)] #[allow(missing_docs)] #[non_exhaustive] -pub enum ExecutionProviderDispatch { - CPU(CPUExecutionProvider), - CUDA(CUDAExecutionProvider), - TensorRT(TensorRTExecutionProvider), - OpenVINO(OpenVINOExecutionProvider), - ACL(ACLExecutionProvider), - OneDNN(OneDNNExecutionProvider), - CoreML(CoreMLExecutionProvider), - DirectML(DirectMLExecutionProvider), - ROCm(ROCmExecutionProvider), - NNAPI(NNAPIExecutionProvider), - QNN(QNNExecutionProvider), - TVM(TVMExecutionProvider), - CANN(CANNExecutionProvider), - XNNPACK(XNNPACKExecutionProvider), - ArmNN(ArmNNExecutionProvider) +pub struct ExecutionProviderDispatch { + pub(crate) inner: Arc<dyn ExecutionProvider>, + error_on_failure: bool } -macro_rules! impl_dispatch { - ($($variant:ident),*) => { - impl ExecutionProvider for ExecutionProviderDispatch { - fn as_str(&self) -> &'static str { - match self { - $(Self::$variant(inner) => inner.as_str(),)* - } - } +impl ExecutionProviderDispatch { + pub(crate) fn new<E: ExecutionProvider + 'static>(ep: E) -> Self { + ExecutionProviderDispatch { + inner: Arc::new(ep) as Arc<dyn ExecutionProvider>, + error_on_failure: false + } + } - fn is_available(&self) -> $crate::Result<bool> { - match self { - $(Self::$variant(inner) => inner.is_available(),)* - } - } + /// Configures this execution provider to silently log an error if registration of the EP fails. + /// This is the default behavior; it can be overridden with [`ExecutionProviderDispatch::error_on_failure`]. 
+ pub fn fail_silently(mut self) -> Self { + self.error_on_failure = false; + self + } - fn register(&self, session_builder: &$crate::SessionBuilder) -> $crate::Result<()> { - match self { - $(Self::$variant(inner) => inner.register(session_builder),)* - } - } - } - }; + /// Configures this execution provider to return an error upon EP registration if registration of this EP fails. + /// The default behavior is to silently fail and fall back to the next execution provider, or the CPU provider if no + /// registrations succeed. + pub fn error_on_failure(mut self) -> Self { + self.error_on_failure = true; + self + } } -impl_dispatch!(CPU, CUDA, TensorRT, ACL, OneDNN, OpenVINO, CoreML, CANN, ROCm, DirectML, TVM, NNAPI, QNN, XNNPACK, ArmNN); +impl Debug for ExecutionProviderDispatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(self.inner.as_str()) + .field("error_on_failure", &self.error_on_failure) + .finish() + } +} #[allow(unused)] macro_rules! map_keys { @@ -207,26 +202,31 @@ macro_rules! 
get_ep_register { pub(crate) use get_ep_register; #[tracing::instrument(skip_all)] -pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator<Item = ExecutionProviderDispatch>) { +pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator<Item = ExecutionProviderDispatch>) -> Result<()> { let execution_providers: Vec<_> = execution_providers.collect(); let mut fallback_to_cpu = !execution_providers.is_empty(); for ex in execution_providers { - if let Err(e) = ex.register(session_builder) { + if let Err(e) = ex.inner.register(session_builder) { + if ex.error_on_failure { + return Err(e); + } + if let &Error::ExecutionProviderNotRegistered(ep_name) = &e { - if ex.supported_by_platform() { + if ex.inner.supported_by_platform() { tracing::warn!("{e}"); } else { - tracing::debug!("{e} (additionally, `{ep_name}` is not supported on this platform)"); + tracing::debug!("{e} (note: additionally, `{ep_name}` is not supported on this platform)"); } } else { - tracing::warn!("An error occurred when attempting to register `{}`: {e}", ex.as_str()); + tracing::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str()); } } else { - tracing::info!("Successfully registered `{}`", ex.as_str()); + tracing::info!("Successfully registered `{}`", ex.inner.as_str()); fallback_to_cpu = false; } } if fallback_to_cpu { tracing::warn!("No execution providers registered successfully. 
Falling back to CPU."); } + Ok(()) } diff --git a/src/execution_providers/nnapi.rs b/src/execution_providers/nnapi.rs index 472db339..9f1951ef 100644 --- a/src/execution_providers/nnapi.rs +++ b/src/execution_providers/nnapi.rs @@ -59,7 +59,7 @@ impl NNAPIExecutionProvider { impl From<NNAPIExecutionProvider> for ExecutionProviderDispatch { fn from(value: NNAPIExecutionProvider) -> Self { - ExecutionProviderDispatch::NNAPI(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/onednn.rs b/src/execution_providers/onednn.rs index 04166757..795d0e66 100644 --- a/src/execution_providers/onednn.rs +++ b/src/execution_providers/onednn.rs @@ -29,7 +29,7 @@ impl OneDNNExecutionProvider { impl From<OneDNNExecutionProvider> for ExecutionProviderDispatch { fn from(value: OneDNNExecutionProvider) -> Self { - ExecutionProviderDispatch::OneDNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/openvino.rs b/src/execution_providers/openvino.rs index fb8f932b..95dc8e26 100644 --- a/src/execution_providers/openvino.rs +++ b/src/execution_providers/openvino.rs @@ -103,7 +103,7 @@ impl OpenVINOExecutionProvider { impl From<OpenVINOExecutionProvider> for ExecutionProviderDispatch { fn from(value: OpenVINOExecutionProvider) -> Self { - ExecutionProviderDispatch::OpenVINO(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/qnn.rs b/src/execution_providers/qnn.rs index 6262aac3..eb7075d5 100644 --- a/src/execution_providers/qnn.rs +++ b/src/execution_providers/qnn.rs @@ -110,7 +110,7 @@ impl QNNExecutionProvider { impl From<QNNExecutionProvider> for ExecutionProviderDispatch { fn from(value: QNNExecutionProvider) -> Self { - ExecutionProviderDispatch::QNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/rocm.rs b/src/execution_providers/rocm.rs index c2c28857..be4cfdea 100644 --- a/src/execution_providers/rocm.rs +++ b/src/execution_providers/rocm.rs @@ -114,7 +114,7 @@ impl ROCmExecutionProvider { impl From<ROCmExecutionProvider> for ExecutionProviderDispatch { fn 
from(value: ROCmExecutionProvider) -> Self { - ExecutionProviderDispatch::ROCm(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/tensorrt.rs b/src/execution_providers/tensorrt.rs index fe581c34..e60e16f0 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/execution_providers/tensorrt.rs @@ -210,7 +210,7 @@ impl TensorRTExecutionProvider { impl From<TensorRTExecutionProvider> for ExecutionProviderDispatch { fn from(value: TensorRTExecutionProvider) -> Self { - ExecutionProviderDispatch::TensorRT(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/tvm.rs b/src/execution_providers/tvm.rs index a054a704..19c8ea7a 100644 --- a/src/execution_providers/tvm.rs +++ b/src/execution_providers/tvm.rs @@ -54,7 +54,7 @@ impl TVMExecutionProvider { impl From<TVMExecutionProvider> for ExecutionProviderDispatch { fn from(value: TVMExecutionProvider) -> Self { - ExecutionProviderDispatch::TVM(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/xnnpack.rs b/src/execution_providers/xnnpack.rs index b344cc3b..87933260 100644 --- a/src/execution_providers/xnnpack.rs +++ b/src/execution_providers/xnnpack.rs @@ -23,7 +23,7 @@ impl XNNPACKExecutionProvider { impl From<XNNPACKExecutionProvider> for ExecutionProviderDispatch { fn from(value: XNNPACKExecutionProvider) -> Self { - ExecutionProviderDispatch::XNNPACK(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/session/builder.rs b/src/session/builder.rs index 60f716e0..7d654c2a 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -112,7 +112,7 @@ impl SessionBuilder { /// `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by /// providing a `Vec` of [`ExecutionProviderDispatch`]es. 
pub fn with_execution_providers(self, execution_providers: impl IntoIterator<Item = ExecutionProviderDispatch>) -> Result<Self> { - apply_execution_providers(&self, execution_providers.into_iter()); + apply_execution_providers(&self, execution_providers.into_iter())?; Ok(self) } @@ -329,7 +329,7 @@ impl SessionBuilder { .collect(); let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned()); + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; if env.has_global_threadpool { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; @@ -406,7 +406,7 @@ impl SessionBuilder { let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned()); + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; if env.has_global_threadpool { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];