diff --git a/src/environment.rs b/src/environment.rs index 338b60ec..8aabfa57 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -1,11 +1,7 @@ use std::{ - cell::UnsafeCell, ffi::{self, CStr, CString}, - ptr, - sync::{ - atomic::{AtomicPtr, Ordering}, - Arc - } + ptr::{self, NonNull}, + sync::{Arc, RwLock} }; use ort_sys::c_char; @@ -20,12 +16,12 @@ use crate::{ }; struct EnvironmentSingleton { - cell: UnsafeCell>> + lock: RwLock>> } unsafe impl Sync for EnvironmentSingleton {} -static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) }; +static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(None) }; /// An `Environment` is a process-global structure, under which [`Session`](crate::Session)s are created. /// @@ -41,14 +37,14 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::ne #[derive(Debug)] pub struct Environment { pub(crate) execution_providers: Vec, - pub(crate) env_ptr: AtomicPtr, + pub(crate) env_ptr: NonNull, pub(crate) has_global_threadpool: bool } impl Environment { /// Returns the underlying [`ort_sys::OrtEnv`] pointer. pub fn ptr(&self) -> *mut ort_sys::OrtEnv { - self.env_ptr.load(Ordering::Relaxed) + self.env_ptr.as_ptr() } } @@ -57,23 +53,22 @@ impl Drop for Environment { fn drop(&mut self) { debug!("Releasing environment"); - let env_ptr: *mut ort_sys::OrtEnv = *self.env_ptr.get_mut(); - - assert_ne!(env_ptr, std::ptr::null_mut()); - ortsys![unsafe ReleaseEnv(env_ptr)]; + ortsys![unsafe ReleaseEnv(self.env_ptr.as_ptr())]; } } /// Gets a reference to the global environment, creating one if an environment has not been /// [`commit`](EnvironmentBuilder::commit)ted yet. -pub fn get_environment() -> Result<&'static Arc> { - if let Some(c) = unsafe { &*G_ENV.cell.get() } { - Ok(c) +pub fn get_environment() -> Result> { + let env = G_ENV.lock.read().expect("poisoned lock"); + if let Some(env) = env.as_ref() { + Ok(Arc::clone(env)) } else { - debug!("Environment not yet initialized, creating a new one"); - EnvironmentBuilder::new().commit()?; + // drop our read lock so we dont deadlock when `commit` takes a write lock + drop(env); - Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() }) + debug!("Environment not yet initialized, creating a new one"); + Ok(EnvironmentBuilder::new().commit()?) } } @@ -151,12 +146,7 @@ impl EnvironmentBuilder { } /// Commit the environment configuration and set the global environment. - pub fn commit(self) -> Result<()> { - // drop global reference to previous environment - if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } { - drop(env_arc); - } - + pub fn commit(self) -> Result> { let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options { let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); @@ -218,15 +208,20 @@ impl EnvironmentBuilder { ortsys![unsafe DisableTelemetryEvents(env_ptr) -> Error::CreateEnvironment]; } - unsafe { - *G_ENV.cell.get() = Some(Arc::new(Environment { - execution_providers: self.execution_providers, - env_ptr: AtomicPtr::new(env_ptr), - has_global_threadpool - })); - }; - - Ok(()) + let mut env_lock = G_ENV.lock.write().expect("poisoned lock"); + // drop global reference to previous environment + if let Some(env_arc) = env_lock.take() { + drop(env_arc); + } + let env = Arc::new(Environment { + execution_providers: self.execution_providers, + // we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call + env_ptr: unsafe { NonNull::new_unchecked(env_ptr) }, + has_global_threadpool + }); + env_lock.replace(Arc::clone(&env)); + + Ok(env) } } @@ -316,16 +311,11 @@ mod tests { use super::*; fn is_env_initialized() -> bool { - unsafe { (*G_ENV.cell.get()).as_ref() }.is_some() - && !unsafe { (*G_ENV.cell.get()).as_ref() } - .unwrap_or_else(|| unreachable!()) - .env_ptr - .load(Ordering::Relaxed) - .is_null() + G_ENV.lock.read().expect("poisoned lock").is_some() } fn env_ptr() -> Option<*mut ort_sys::OrtEnv> { - unsafe { (*G_ENV.cell.get()).as_ref() }.map(|f| f.env_ptr.load(Ordering::Relaxed)) + (*G_ENV.lock.read().expect("poisoned lock")).as_ref().map(|f| f.env_ptr.as_ptr()) } struct ConcurrentTestRun { diff --git a/src/session/builder.rs b/src/session/builder.rs index f650db54..2fa02fc1 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -8,7 +8,7 @@ use std::{ path::Path, ptr::{self, NonNull}, rc::Rc, - sync::{atomic::Ordering, Arc} + sync::Arc }; use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}; @@ -312,10 +312,8 @@ impl SessionBuilder { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; } - let env_ptr = env.env_ptr.load(Ordering::Relaxed); - let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); - ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)]; + ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)]; let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; @@ -348,7 +346,7 @@ impl SessionBuilder { session_ptr, allocator, _extras: extras, - _environment: Arc::clone(env) + _environment: env }), inputs, outputs @@ -389,12 +387,10 @@ impl SessionBuilder { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; } - let env_ptr = env.env_ptr.load(Ordering::Relaxed); - let model_data = model_bytes.as_ptr().cast::(); let model_data_length = model_bytes.len(); ortsys![ - unsafe CreateSessionFromArray(env_ptr, model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; + unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr) ]; @@ -429,7 +425,7 @@ impl SessionBuilder { session_ptr, allocator, _extras: extras, - _environment: Arc::clone(env) + _environment: env }), inputs, outputs