Skip to content

Commit

Permalink
feat: option to error out session builder if EP registration fails
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jun 14, 2024
1 parent 812fdb0 commit 23fce78
Show file tree
Hide file tree
Showing 19 changed files with 89 additions and 74 deletions.
17 changes: 16 additions & 1 deletion docs/pages/perf/execution-providers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,22 @@ fn main() -> anyhow::Result<()> {
## Fallback behavior
`ort` will silently fail and fall back to executing on the CPU if all execution providers fail to register. In many cases, though, you'll want to show the user an error message when an EP fails to register, or outright abort the process.

To receive these registration errors, instead use `ExecutionProvider::register` to register an execution provider:
You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`:
```rust
use ort::{CUDAExecutionProvider, Session};

fn main() -> anyhow::Result<()> {
let session = Session::builder()?
.with_execution_providers([
CUDAExecutionProvider::default().build().error_on_failure()
])?
.commit_from_file("model.onnx")?;

Ok(())
}
```

If you require more complex error handling, you can also manually register execution providers via the `ExecutionProvider::register` method:

```rust
use ort::{CUDAExecutionProvider, ExecutionProvider, Session};
Expand Down
2 changes: 1 addition & 1 deletion examples/cudarc/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();

ort::init()
.with_execution_providers([CUDAExecutionProvider::default().build()])
.with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()])
.commit()?;

let model =
Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/acl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl ACLExecutionProvider {

impl From<ACLExecutionProvider> for ExecutionProviderDispatch {
fn from(value: ACLExecutionProvider) -> Self {
ExecutionProviderDispatch::ACL(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/armnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl ArmNNExecutionProvider {

impl From<ArmNNExecutionProvider> for ExecutionProviderDispatch {
fn from(value: ArmNNExecutionProvider) -> Self {
ExecutionProviderDispatch::ArmNN(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/cann.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl CANNExecutionProvider {

impl From<CANNExecutionProvider> for ExecutionProviderDispatch {
fn from(value: CANNExecutionProvider) -> Self {
ExecutionProviderDispatch::CANN(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/coreml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl CoreMLExecutionProvider {

impl From<CoreMLExecutionProvider> for ExecutionProviderDispatch {
fn from(value: CoreMLExecutionProvider) -> Self {
ExecutionProviderDispatch::CoreML(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl CPUExecutionProvider {

impl From<CPUExecutionProvider> for ExecutionProviderDispatch {
fn from(value: CPUExecutionProvider) -> Self {
ExecutionProviderDispatch::CPU(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl CUDAExecutionProvider {

impl From<CUDAExecutionProvider> for ExecutionProviderDispatch {
fn from(value: CUDAExecutionProvider) -> Self {
ExecutionProviderDispatch::CUDA(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/directml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl DirectMLExecutionProvider {

impl From<DirectMLExecutionProvider> for ExecutionProviderDispatch {
fn from(value: DirectMLExecutionProvider) -> Self {
ExecutionProviderDispatch::DirectML(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
108 changes: 54 additions & 54 deletions src/execution_providers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt::Debug, os::raw::c_char};
use std::{fmt::Debug, os::raw::c_char, sync::Arc};

use crate::{char_p_to_string, ortsys, Error, Result, SessionBuilder};

Expand Down Expand Up @@ -60,16 +60,17 @@ pub trait ExecutionProvider {
true
}

/// Returns `Ok(true)` if ONNX Runtime was compiled with support for this execution provider, and `Ok(false)`
/// Returns `Ok(true)` if ONNX Runtime was *compiled with support* for this execution provider, and `Ok(false)`
/// otherwise.
///
/// An `Err` may be returned if a serious internal error occurs, in which case your application should probably
/// just abort.
///
/// Note that this does not always mean the execution provider is *usable* for a specific model. A model may use
/// operators not supported by an execution provider, or the EP may encounter an error while attempting to load a
/// dynamic library during registration. In most cases (i.e. showing the user an error message if CUDA could not be
/// enabled), you'll instead want to detect and handle errors from [`ExecutionProvider::register`].
/// **Note that this does not always mean the execution provider is *usable* for a specific session.** A model may
/// use operators not supported by an execution provider, or the EP may encounter an error while attempting to load
/// dependencies during session creation. In most cases (e.g. showing the user an error message if CUDA could not be
/// enabled), you'll instead want to manually register this EP via [`ExecutionProvider::register`] and detect
/// and handle any errors returned by that function.
fn is_available(&self) -> Result<bool> {
let mut providers: *mut *mut c_char = std::ptr::null_mut();
let mut num_providers = 0;
Expand Down Expand Up @@ -110,56 +111,50 @@ pub enum ArenaExtendStrategy {
SameAsRequested
}

/// Execution provider container. See [the ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/) for more
/// info on execution providers. Execution providers are actually registered via the functions [`crate::SessionBuilder`]
/// (per-session) or [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder) (default for all sessions in an
/// environment).
#[derive(Debug, Clone)]
/// Dynamic execution provider container, used to provide a list of multiple types of execution providers when
/// configuring execution providers for a [`SessionBuilder`](crate::SessionBuilder) or
/// [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder).
///
/// See [`ExecutionProvider`] for more info on execution providers.
#[derive(Clone)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum ExecutionProviderDispatch {
CPU(CPUExecutionProvider),
CUDA(CUDAExecutionProvider),
TensorRT(TensorRTExecutionProvider),
OpenVINO(OpenVINOExecutionProvider),
ACL(ACLExecutionProvider),
OneDNN(OneDNNExecutionProvider),
CoreML(CoreMLExecutionProvider),
DirectML(DirectMLExecutionProvider),
ROCm(ROCmExecutionProvider),
NNAPI(NNAPIExecutionProvider),
QNN(QNNExecutionProvider),
TVM(TVMExecutionProvider),
CANN(CANNExecutionProvider),
XNNPACK(XNNPACKExecutionProvider),
ArmNN(ArmNNExecutionProvider)
pub struct ExecutionProviderDispatch {
pub(crate) inner: Arc<dyn ExecutionProvider>,
error_on_failure: bool
}

macro_rules! impl_dispatch {
($($variant:ident),*) => {
impl ExecutionProvider for ExecutionProviderDispatch {
fn as_str(&self) -> &'static str {
match self {
$(Self::$variant(inner) => inner.as_str(),)*
}
}
impl ExecutionProviderDispatch {
pub(crate) fn new<E: ExecutionProvider + 'static>(ep: E) -> Self {
ExecutionProviderDispatch {
inner: Arc::new(ep) as Arc<dyn ExecutionProvider>,
error_on_failure: false
}
}

fn is_available(&self) -> $crate::Result<bool> {
match self {
$(Self::$variant(inner) => inner.is_available(),)*
}
}
/// Configures this execution provider to silently log an error if registration of the EP fails.
/// This is the default behavior; it can be overridden with [`ExecutionProviderDispatch::error_on_failure`].
pub fn fail_silently(mut self) -> Self {
self.error_on_failure = false;
self
}

fn register(&self, session_builder: &$crate::SessionBuilder) -> $crate::Result<()> {
match self {
$(Self::$variant(inner) => inner.register(session_builder),)*
}
}
}
};
/// Configures this execution provider to return an error upon EP registration if registration of this EP fails.
/// The default behavior is to silently fail and fall back to the next execution provider, or the CPU provider if no
/// registrations succeed.
pub fn error_on_failure(mut self) -> Self {
self.error_on_failure = true;
self
}
}

impl_dispatch!(CPU, CUDA, TensorRT, ACL, OneDNN, OpenVINO, CoreML, CANN, ROCm, DirectML, TVM, NNAPI, QNN, XNNPACK, ArmNN);
/// Debug output is a struct-style record named after the wrapped execution provider
/// (the string returned by its `as_str()`), with `error_on_failure` as its only field.
impl Debug for ExecutionProviderDispatch {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		// The wrapped provider is a trait object, so its name stands in for a type name.
		let provider_name = self.inner.as_str();
		let mut repr = f.debug_struct(provider_name);
		repr.field("error_on_failure", &self.error_on_failure);
		repr.finish()
	}
}

#[allow(unused)]
macro_rules! map_keys {
Expand Down Expand Up @@ -207,26 +202,31 @@ macro_rules! get_ep_register {
pub(crate) use get_ep_register;

#[tracing::instrument(skip_all)]
pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator<Item = ExecutionProviderDispatch>) {
pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator<Item = ExecutionProviderDispatch>) -> Result<()> {
let execution_providers: Vec<_> = execution_providers.collect();
let mut fallback_to_cpu = !execution_providers.is_empty();
for ex in execution_providers {
if let Err(e) = ex.register(session_builder) {
if let Err(e) = ex.inner.register(session_builder) {
if ex.error_on_failure {
return Err(e);
}

if let &Error::ExecutionProviderNotRegistered(ep_name) = &e {
if ex.supported_by_platform() {
if ex.inner.supported_by_platform() {
tracing::warn!("{e}");
} else {
tracing::debug!("{e} (additionally, `{ep_name}` is not supported on this platform)");
tracing::debug!("{e} (note: additionally, `{ep_name}` is not supported on this platform)");
}
} else {
tracing::warn!("An error occurred when attempting to register `{}`: {e}", ex.as_str());
tracing::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str());
}
} else {
tracing::info!("Successfully registered `{}`", ex.as_str());
tracing::info!("Successfully registered `{}`", ex.inner.as_str());
fallback_to_cpu = false;
}
}
if fallback_to_cpu {
tracing::warn!("No execution providers registered successfully. Falling back to CPU.");
}
Ok(())
}
2 changes: 1 addition & 1 deletion src/execution_providers/nnapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl NNAPIExecutionProvider {

impl From<NNAPIExecutionProvider> for ExecutionProviderDispatch {
fn from(value: NNAPIExecutionProvider) -> Self {
ExecutionProviderDispatch::NNAPI(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/onednn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl OneDNNExecutionProvider {

impl From<OneDNNExecutionProvider> for ExecutionProviderDispatch {
fn from(value: OneDNNExecutionProvider) -> Self {
ExecutionProviderDispatch::OneDNN(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/openvino.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl OpenVINOExecutionProvider {

impl From<OpenVINOExecutionProvider> for ExecutionProviderDispatch {
fn from(value: OpenVINOExecutionProvider) -> Self {
ExecutionProviderDispatch::OpenVINO(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/qnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl QNNExecutionProvider {

impl From<QNNExecutionProvider> for ExecutionProviderDispatch {
fn from(value: QNNExecutionProvider) -> Self {
ExecutionProviderDispatch::QNN(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/rocm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl ROCmExecutionProvider {

impl From<ROCmExecutionProvider> for ExecutionProviderDispatch {
fn from(value: ROCmExecutionProvider) -> Self {
ExecutionProviderDispatch::ROCm(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/tensorrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ impl TensorRTExecutionProvider {

impl From<TensorRTExecutionProvider> for ExecutionProviderDispatch {
fn from(value: TensorRTExecutionProvider) -> Self {
ExecutionProviderDispatch::TensorRT(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/tvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl TVMExecutionProvider {

impl From<TVMExecutionProvider> for ExecutionProviderDispatch {
fn from(value: TVMExecutionProvider) -> Self {
ExecutionProviderDispatch::TVM(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_providers/xnnpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl XNNPACKExecutionProvider {

impl From<XNNPACKExecutionProvider> for ExecutionProviderDispatch {
fn from(value: XNNPACKExecutionProvider) -> Self {
ExecutionProviderDispatch::XNNPACK(value)
ExecutionProviderDispatch::new(value)
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl SessionBuilder {
/// `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by
/// providing a `Vec` of [`ExecutionProviderDispatch`]es.
pub fn with_execution_providers(self, execution_providers: impl IntoIterator<Item = ExecutionProviderDispatch>) -> Result<Self> {
apply_execution_providers(&self, execution_providers.into_iter());
apply_execution_providers(&self, execution_providers.into_iter())?;
Ok(self)
}

Expand Down Expand Up @@ -329,7 +329,7 @@ impl SessionBuilder {
.collect();

let env = get_environment()?;
apply_execution_providers(&self, env.execution_providers.iter().cloned());
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;

if env.has_global_threadpool {
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
Expand Down Expand Up @@ -406,7 +406,7 @@ impl SessionBuilder {
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();

let env = get_environment()?;
apply_execution_providers(&self, env.execution_providers.iter().cloned());
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;

if env.has_global_threadpool {
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
Expand Down

0 comments on commit 23fce78

Please sign in to comment.