diff --git a/src/error.rs b/src/error.rs index 1eb32074..5f48c709 100644 --- a/src/error.rs +++ b/src/error.rs @@ -38,6 +38,12 @@ pub enum Error { /// Error occurred when creating ONNX session options. #[error("Failed to create ONNX Runtime session options: {0}")] CreateSessionOptions(ErrorInternal), + /// Failed to enable `onnxruntime-extensions` for session. + #[error("Failed to enable `onnxruntime-extensions`: {0}")] + EnableExtensions(ErrorInternal), + /// Failed to add a configuration entry to the session builder's options. + #[error("Failed to add configuration entry to session builder: {0}")] + AddSessionConfigEntry(ErrorInternal), /// Error occurred when creating an allocator from a [`crate::MemoryInfo`] struct while building a session. #[error("Failed to create allocator from memory info: {0}")] CreateAllocator(ErrorInternal), diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs new file mode 100644 index 00000000..e28973f5 --- /dev/null +++ b/src/session/builder/impl_commit.rs @@ -0,0 +1,203 @@ +#[cfg(feature = "fetch-models")] +use std::fmt::Write; +use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc}; + +use super::SessionBuilder; +#[cfg(feature = "fetch-models")] +use crate::error::FetchModelError; +use crate::{ + environment::get_environment, + error::{Error, Result}, + execution_providers::apply_execution_providers, + memory::Allocator, + ortsys, + session::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner} +}; + +impl SessionBuilder { + /// Downloads a pre-trained ONNX model from the given URL and builds the session. 
+ #[cfg(feature = "fetch-models")] + #[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] + pub fn commit_from_url(self, model_url: impl AsRef) -> Result { + let mut download_dir = ort_sys::internal::dirs::cache_dir() + .expect("could not determine cache directory") + .join("models"); + if std::fs::create_dir_all(&download_dir).is_err() { + download_dir = std::env::current_dir().expect("Failed to obtain current working directory"); + } + + let url = model_url.as_ref(); + let model_filename = ::digest(url).into_iter().fold(String::new(), |mut s, b| { + let _ = write!(&mut s, "{:02x}", b); + s + }); + let model_filepath = download_dir.join(model_filename); + let downloaded_path = if model_filepath.exists() { + tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); + model_filepath + } else { + tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); + + let resp = ureq::get(url).call().map_err(Box::new).map_err(FetchModelError::FetchError)?; + + let len = resp + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + .expect("Missing Content-Length header"); + tracing::info!(len, "Downloading {} bytes", len); + + let mut reader = resp.into_reader(); + + let f = std::fs::File::create(&model_filepath).map_err(FetchModelError::IoError)?; + let mut writer = std::io::BufWriter::new(f); + + let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(FetchModelError::IoError)?; + if bytes_io_count == len as u64 { + model_filepath + } else { + return Err(FetchModelError::CopyError { + expected: len as u64, + io: bytes_io_count + } + .into()); + } + }; + + self.commit_from_file(downloaded_path) + } + + /// Loads an ONNX model from a file and builds the session. + pub fn commit_from_file

(mut self, model_filepath_ref: P) -> Result + where + P: AsRef + { + let model_filepath = model_filepath_ref.as_ref(); + if !model_filepath.exists() { + return Err(Error::FileDoesNotExist { + filename: model_filepath.to_path_buf() + }); + } + + let model_path = crate::util::path_to_os_char(model_filepath); + + let env = get_environment()?; + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; + + if env.has_global_threadpool { + ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; + } + + let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); + ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)]; + + let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; + + let allocator = match &self.memory_info { + Some(info) => { + let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); + ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; + unsafe { Allocator::from_raw_unchecked(allocator_ptr) } + } + None => Allocator::default() + }; + + // Extract input and output properties + let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; + let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; + let inputs = (0..num_input_nodes) + .map(|i| dangerous::extract_input(session_ptr, &allocator, i)) + .collect::>>()?; + let outputs = (0..num_output_nodes) + .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) + .collect::>>()?; + + let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box); + #[cfg(feature = "operator-libraries")] + let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box)); + let extras: Vec> = extras.collect(); + + Ok(Session { + inner: 
Arc::new(SharedSessionInner { + session_ptr, + allocator, + _extras: extras, + _environment: env + }), + inputs, + outputs + }) + } + + /// Load an ONNX graph from memory and commit the session. + /// For `.ort` models, we enable `session.use_ort_model_bytes_directly` and `session.use_ort_model_bytes_for_initializers`. + /// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array). + /// + /// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that + /// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros). + pub fn commit_from_memory_directly(mut self, model_bytes: &[u8]) -> Result> { + // Enable zero-copy deserialization for models in `.ort` format. + self.add_config_entry("session.use_ort_model_bytes_directly", "1")?; + self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?; + + let session = self.commit_from_memory(model_bytes)?; + + Ok(InMemorySession { session, phantom: PhantomData }) + } + + /// Load an ONNX graph from memory and commit the session. 
+ pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result { + let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); + + let env = get_environment()?; + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; + + if env.has_global_threadpool { + ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; + } + + let model_data = model_bytes.as_ptr().cast::(); + let model_data_length = model_bytes.len(); + ortsys![ + unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; + nonNull(session_ptr) + ]; + + let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; + + let allocator = match &self.memory_info { + Some(info) => { + let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); + ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; + unsafe { Allocator::from_raw_unchecked(allocator_ptr) } + } + None => Allocator::default() + }; + + // Extract input and output properties + let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; + let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; + let inputs = (0..num_input_nodes) + .map(|i| dangerous::extract_input(session_ptr, &allocator, i)) + .collect::>>()?; + let outputs = (0..num_output_nodes) + .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) + .collect::>>()?; + + let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box); + #[cfg(feature = "operator-libraries")] + let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box)); + let extras: Vec> = extras.collect(); + + let session = Session { + inner: Arc::new(SharedSessionInner { + session_ptr, + allocator, + _extras: extras, + _environment: env + }), + inputs, + 
outputs + }; + Ok(session) + } +} diff --git a/src/session/builder/impl_config_keys.rs b/src/session/builder/impl_config_keys.rs new file mode 100644 index 00000000..e610f419 --- /dev/null +++ b/src/session/builder/impl_config_keys.rs @@ -0,0 +1,100 @@ +use super::SessionBuilder; +use crate::Result; + +// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h + +impl SessionBuilder { + /// Enable/disable the usage of prepacking. + /// + /// This option is **enabled** by default. + pub fn with_prepacking(mut self, enable: bool) -> Result { + self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" })?; + Ok(self) + } + + /// Use allocators from the registered environment. + /// + /// This option is **disabled** by default. + pub fn with_env_allocators(mut self) -> Result { + self.add_config_entry("session.use_env_allocators", "1")?; + Ok(self) + } + + /// Enable flush-to-zero and denormal-as-zero. + /// + /// This option is **disabled** by default, as it may hurt model accuracy. + pub fn with_denormal_as_zero(mut self) -> Result { + self.add_config_entry("session.set_denormal_as_zero", "1")?; + Ok(self) + } + + /// Enable/disable fusion for quantized models in QDQ (QuantizeLinear/DequantizeLinear) format. + /// + /// This option is **enabled** by default for all EPs except DirectML. + pub fn with_quant_qdq(mut self, enable: bool) -> Result { + self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" })?; + Ok(self) + } + + /// Enable/disable the optimization step removing double QDQ nodes. + /// + /// This option is **enabled** by default. + pub fn with_double_qdq_remover(mut self, enable: bool) -> Result { + self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" })?; + Ok(self) + } + + /// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed. 
+ /// + /// This option is **disabled** by default. + pub fn with_qdq_cleanup(mut self) -> Result { + self.add_config_entry("session.enable_quant_qdq_cleanup", "1")?; + Ok(self) + } + + /// Enable fast GELU approximation. + /// + /// This option is **disabled** by default, as it may hurt accuracy. + pub fn with_approximate_gelu(mut self) -> Result { + self.add_config_entry("optimization.enable_gelu_approximation", "1")?; + Ok(self) + } + + /// Enable/disable ahead-of-time function inlining. + /// + /// This option is **enabled** by default. + pub fn with_aot_inlining(mut self, enable: bool) -> Result { + self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" })?; + Ok(self) + } + + /// Accepts a comma-separated list of optimizers to disable. + pub fn with_disabled_optimizers(mut self, optimizers: &str) -> Result { + self.add_config_entry("optimization.disable_specified_optimizers", optimizers)?; + Ok(self) + } + + /// Enable using device allocator for allocating initialized tensor memory. + /// + /// This option is **disabled** by default. + pub fn with_device_allocator_for_initializers(mut self) -> Result { + self.add_config_entry("session.use_device_allocator_for_initializers", "1")?; + Ok(self) + } + + /// Enable/disable allowing the inter-op threads to spin for a short period before blocking. + /// + /// This option is **enabled** by default. + pub fn with_inter_op_spinning(mut self, enable: bool) -> Result { + self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" })?; + Ok(self) + } + + /// Enable/disable allowing the intra-op threads to spin for a short period before blocking. + /// + /// This option is **enabled** by default. 
+ pub fn with_intra_op_spinning(mut self, enable: bool) -> Result { + self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" })?; + Ok(self) + } +} diff --git a/src/session/builder.rs b/src/session/builder/impl_options.rs similarity index 55% rename from src/session/builder.rs rename to src/session/builder/impl_options.rs index 2fa02fc1..6e48b957 100644 --- a/src/session/builder.rs +++ b/src/session/builder/impl_options.rs @@ -1,99 +1,13 @@ -#[cfg(any(feature = "operator-libraries", not(windows)))] -use std::ffi::CString; -#[cfg(feature = "fetch-models")] -use std::fmt::Write; -use std::{ - any::Any, - marker::PhantomData, - path::Path, - ptr::{self, NonNull}, - rc::Rc, - sync::Arc -}; +use std::{rc::Rc, sync::Arc}; -use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}; -#[cfg(feature = "fetch-models")] -use crate::error::FetchModelError; +use super::SessionBuilder; use crate::{ - environment::get_environment, - error::{assert_non_null_pointer, status_to_result, Error, Result}, + error::{Error, Result}, execution_providers::{apply_execution_providers, ExecutionProviderDispatch}, - memory::{Allocator, MemoryInfo}, - operator::OperatorDomain, - ortsys + ortsys, MemoryInfo, OperatorDomain }; -/// Creates a session using the builder pattern. -/// -/// Once configured, use the [`SessionBuilder::commit_from_file`](crate::SessionBuilder::commit_from_file) -/// method to 'commit' the builder configuration into a [`Session`]. -/// -/// ``` -/// # use ort::{GraphOptimizationLevel, Session}; -/// # fn main() -> ort::Result<()> { -/// let session = Session::builder()? -/// .with_optimization_level(GraphOptimizationLevel::Level1)? -/// .with_intra_threads(1)? 
-/// .commit_from_file("tests/data/upsample.onnx")?; -/// # Ok(()) -/// # } -/// ``` -pub struct SessionBuilder { - pub(crate) session_options_ptr: NonNull, - memory_info: Option>, - #[cfg(feature = "operator-libraries")] - custom_runtime_handles: Vec>, - operator_domains: Vec> -} - -impl Clone for SessionBuilder { - fn clone(&self) -> Self { - let mut session_options_ptr = ptr::null_mut(); - status_to_result(ortsys![unsafe CloneSessionOptions(self.session_options_ptr.as_ptr(), ptr::addr_of_mut!(session_options_ptr))]) - .expect("error cloning session options"); - assert_non_null_pointer(session_options_ptr, "OrtSessionOptions").expect("Cloned session option pointer is null"); - Self { - session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) }, - memory_info: self.memory_info.clone(), - #[cfg(feature = "operator-libraries")] - custom_runtime_handles: self.custom_runtime_handles.clone(), - operator_domains: self.operator_domains.clone() - } - } -} - -impl Drop for SessionBuilder { - fn drop(&mut self) { - ortsys![unsafe ReleaseSessionOptions(self.session_options_ptr.as_ptr())]; - } -} - impl SessionBuilder { - /// Creates a new session builder. - /// - /// ``` - /// # use ort::{GraphOptimizationLevel, Session}; - /// # fn main() -> ort::Result<()> { - /// let session = Session::builder()? - /// .with_optimization_level(GraphOptimizationLevel::Level1)? - /// .with_intra_threads(1)? 
- /// .commit_from_file("tests/data/upsample.onnx")?; - /// # Ok(()) - /// # } - /// ``` - pub fn new() -> Result { - let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = std::ptr::null_mut(); - ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> Error::CreateSessionOptions; nonNull(session_options_ptr)]; - - Ok(Self { - session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) }, - memory_info: None, - #[cfg(feature = "operator-libraries")] - custom_runtime_handles: Vec::new(), - operator_domains: Vec::new() - }) - } - /// Registers a list of execution providers for this session. Execution providers are registered in the order they /// are provided. /// @@ -203,6 +117,10 @@ impl SessionBuilder { #[cfg(feature = "operator-libraries")] #[cfg_attr(docsrs, doc(cfg(feature = "operator-libraries")))] pub fn with_operator_library(mut self, lib_path: impl AsRef) -> Result { + use std::ffi::CString; + + use crate::error::status_to_result; + let path_cstr = CString::new(lib_path.as_ref())?; let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut(); @@ -228,8 +146,7 @@ impl SessionBuilder { /// Enables [`onnxruntime-extensions`](https://github.com/microsoft/onnxruntime-extensions) custom operators. pub fn with_extensions(self) -> Result { - let status = ortsys![unsafe EnableOrtCustomOps(self.session_options_ptr.as_ptr())]; - status_to_result(status).map_err(Error::CreateSessionOptions)?; + ortsys![unsafe EnableOrtCustomOps(self.session_options_ptr.as_ptr()) -> Error::EnableExtensions]; Ok(self) } @@ -239,199 +156,6 @@ impl SessionBuilder { self.operator_domains.push(domain); Ok(self) } - - /// Downloads a pre-trained ONNX model from the given URL and builds the session. 
- #[cfg(feature = "fetch-models")] - #[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] - pub fn commit_from_url(self, model_url: impl AsRef) -> Result { - let mut download_dir = ort_sys::internal::dirs::cache_dir() - .expect("could not determine cache directory") - .join("models"); - if std::fs::create_dir_all(&download_dir).is_err() { - download_dir = std::env::current_dir().expect("Failed to obtain current working directory"); - } - - let url = model_url.as_ref(); - let model_filename = ::digest(url).into_iter().fold(String::new(), |mut s, b| { - let _ = write!(&mut s, "{:02x}", b); - s - }); - let model_filepath = download_dir.join(model_filename); - let downloaded_path = if model_filepath.exists() { - tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); - model_filepath - } else { - tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); - - let resp = ureq::get(url).call().map_err(Box::new).map_err(FetchModelError::FetchError)?; - - let len = resp - .header("Content-Length") - .and_then(|s| s.parse::().ok()) - .expect("Missing Content-Length header"); - tracing::info!(len, "Downloading {} bytes", len); - - let mut reader = resp.into_reader(); - - let f = std::fs::File::create(&model_filepath).expect("Failed to create model file"); - let mut writer = std::io::BufWriter::new(f); - - let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(FetchModelError::IoError)?; - if bytes_io_count == len as u64 { - model_filepath - } else { - return Err(FetchModelError::CopyError { - expected: len as u64, - io: bytes_io_count - } - .into()); - } - }; - - self.commit_from_file(downloaded_path) - } - - /// Loads an ONNX model from a file and builds the session. - pub fn commit_from_file

(mut self, model_filepath_ref: P) -> Result - where - P: AsRef - { - let model_filepath = model_filepath_ref.as_ref(); - if !model_filepath.exists() { - return Err(Error::FileDoesNotExist { - filename: model_filepath.to_path_buf() - }); - } - - let model_path = crate::util::path_to_os_char(model_filepath); - - let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned())?; - - if env.has_global_threadpool { - ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; - } - - let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); - ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)]; - - let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; - - let allocator = match &self.memory_info { - Some(info) => { - let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); - ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; - unsafe { Allocator::from_raw_unchecked(allocator_ptr) } - } - None => Allocator::default() - }; - - // Extract input and output properties - let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; - let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; - let inputs = (0..num_input_nodes) - .map(|i| dangerous::extract_input(session_ptr, &allocator, i)) - .collect::>>()?; - let outputs = (0..num_output_nodes) - .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) - .collect::>>()?; - - let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box); - #[cfg(feature = "operator-libraries")] - let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box)); - let extras: Vec> = extras.collect(); - - Ok(Session { - inner: 
Arc::new(SharedSessionInner { - session_ptr, - allocator, - _extras: extras, - _environment: env - }), - inputs, - outputs - }) - } - - /// Load an ONNX graph from memory and commit the session - /// For `.ort` models, we enable `session.use_ort_model_bytes_directly`. - /// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array). - /// - /// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that - /// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros). - pub fn commit_from_memory_directly(self, model_bytes: &[u8]) -> Result> { - let str_to_char = |s: &str| { - s.as_bytes() - .iter() - .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string - .map(|b| *b as std::os::raw::c_char) - .collect::>() - }; - // Enable zero-copy deserialization for models in `.ort` format. - ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), str_to_char("session.use_ort_model_bytes_directly").as_ptr(), str_to_char("1").as_ptr())]; - ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), str_to_char("session.use_ort_model_bytes_for_initializers").as_ptr(), str_to_char("1").as_ptr())]; - - let session = self.commit_from_memory(model_bytes)?; - - Ok(InMemorySession { session, phantom: PhantomData }) - } - - /// Load an ONNX graph from memory and commit the session. 
- pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result { - let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); - - let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned())?; - - if env.has_global_threadpool { - ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; - } - - let model_data = model_bytes.as_ptr().cast::(); - let model_data_length = model_bytes.len(); - ortsys![ - unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; - nonNull(session_ptr) - ]; - - let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; - - let allocator = match &self.memory_info { - Some(info) => { - let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); - ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; - unsafe { Allocator::from_raw_unchecked(allocator_ptr) } - } - None => Allocator::default() - }; - - // Extract input and output properties - let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; - let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; - let inputs = (0..num_input_nodes) - .map(|i| dangerous::extract_input(session_ptr, &allocator, i)) - .collect::>>()?; - let outputs = (0..num_output_nodes) - .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) - .collect::>>()?; - - let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box); - #[cfg(feature = "operator-libraries")] - let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box)); - let extras: Vec> = extras.collect(); - - let session = Session { - inner: Arc::new(SharedSessionInner { - session_ptr, - allocator, - _extras: extras, - _environment: env - }), - inputs, - 
- outputs - }; - Ok(session) - } } /// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially @@ -526,7 +250,7 @@ impl From for ort_sys::GraphOptimizationLevel { } #[cfg(feature = "operator-libraries")] -struct LibHandle(*mut std::os::raw::c_void); +pub(super) struct LibHandle(*mut std::os::raw::c_void); #[cfg(feature = "operator-libraries")] impl LibHandle { diff --git a/src/session/builder/mod.rs b/src/session/builder/mod.rs new file mode 100644 index 00000000..610024c7 --- /dev/null +++ b/src/session/builder/mod.rs @@ -0,0 +1,101 @@ +// `CString` is imported unconditionally: `add_config_entry` needs it on every platform and feature set. +use std::ffi::CString; +use std::{ + ptr::{self, NonNull}, + rc::Rc, + sync::Arc +}; + +use crate::{ + error::{assert_non_null_pointer, status_to_result, Error, Result}, + memory::MemoryInfo, + operator::OperatorDomain, + ortsys +}; + +mod impl_commit; +mod impl_config_keys; +mod impl_options; + +pub use self::impl_options::GraphOptimizationLevel; +#[cfg(feature = "operator-libraries")] +use self::impl_options::LibHandle; + +/// Creates a session using the builder pattern. +/// +/// Once configured, use the [`SessionBuilder::commit_from_file`](crate::SessionBuilder::commit_from_file) +/// method to 'commit' the builder configuration into a [`Session`](crate::Session). +/// +/// ``` +/// # use ort::{GraphOptimizationLevel, Session}; +/// # fn main() -> ort::Result<()> { +/// let session = Session::builder()? +/// .with_optimization_level(GraphOptimizationLevel::Level1)? +/// .with_intra_threads(1)? 
+/// .commit_from_file("tests/data/upsample.onnx")?; +/// # Ok(()) +/// # } +/// ``` +pub struct SessionBuilder { + pub(crate) session_options_ptr: NonNull, + memory_info: Option>, + #[cfg(feature = "operator-libraries")] + custom_runtime_handles: Vec>, + operator_domains: Vec> +} + +impl Clone for SessionBuilder { + fn clone(&self) -> Self { + let mut session_options_ptr = ptr::null_mut(); + status_to_result(ortsys![unsafe CloneSessionOptions(self.session_options_ptr.as_ptr(), ptr::addr_of_mut!(session_options_ptr))]) + .expect("error cloning session options"); + assert_non_null_pointer(session_options_ptr, "OrtSessionOptions").expect("Cloned session option pointer is null"); + Self { + session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) }, + memory_info: self.memory_info.clone(), + #[cfg(feature = "operator-libraries")] + custom_runtime_handles: self.custom_runtime_handles.clone(), + operator_domains: self.operator_domains.clone() + } + } +} + +impl Drop for SessionBuilder { + fn drop(&mut self) { + ortsys![unsafe ReleaseSessionOptions(self.session_options_ptr.as_ptr())]; + } +} + +impl SessionBuilder { + /// Creates a new session builder. + /// + /// ``` + /// # use ort::{GraphOptimizationLevel, Session}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()? + /// .with_optimization_level(GraphOptimizationLevel::Level1)? + /// .with_intra_threads(1)? 
+ /// .commit_from_file("tests/data/upsample.onnx")?; + /// # Ok(()) + /// # } + /// ``` + pub fn new() -> Result { + let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = std::ptr::null_mut(); + ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> Error::CreateSessionOptions; nonNull(session_options_ptr)]; + + Ok(Self { + session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) }, + memory_info: None, + #[cfg(feature = "operator-libraries")] + custom_runtime_handles: Vec::new(), + operator_domains: Vec::new() + }) + } + + pub(crate) fn add_config_entry(&mut self, key: &str, value: &str) -> Result<()> { + let key = CString::new(key)?; + let value = CString::new(value)?; + ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr.as_ptr(), key.as_ptr(), value.as_ptr()) -> Error::AddSessionConfigEntry]; + Ok(()) + } +}