Skip to content

Commit

Permalink
fix: thread-safe environment initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Aug 29, 2024
1 parent 9f4527c commit edcb219
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 51 deletions.
74 changes: 32 additions & 42 deletions src/environment.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
use std::{
cell::UnsafeCell,
ffi::{self, CStr, CString},
ptr,
sync::{
atomic::{AtomicPtr, Ordering},
Arc
}
ptr::{self, NonNull},
sync::{Arc, RwLock}
};

use ort_sys::c_char;
Expand All @@ -20,12 +16,12 @@ use crate::{
};

struct EnvironmentSingleton {
cell: UnsafeCell<Option<Arc<Environment>>>
lock: RwLock<Option<Arc<Environment>>>
}

unsafe impl Sync for EnvironmentSingleton {}

static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) };
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(None) };

/// An `Environment` is a process-global structure, under which [`Session`](crate::Session)s are created.
///
Expand All @@ -41,14 +37,14 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::ne
#[derive(Debug)]
pub struct Environment {
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
pub(crate) env_ptr: AtomicPtr<ort_sys::OrtEnv>,
pub(crate) env_ptr: NonNull<ort_sys::OrtEnv>,
pub(crate) has_global_threadpool: bool
}

impl Environment {
/// Returns the underlying [`ort_sys::OrtEnv`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtEnv {
self.env_ptr.load(Ordering::Relaxed)
self.env_ptr.as_ptr()
}
}

Expand All @@ -57,23 +53,22 @@ impl Drop for Environment {
fn drop(&mut self) {
debug!("Releasing environment");

let env_ptr: *mut ort_sys::OrtEnv = *self.env_ptr.get_mut();

assert_ne!(env_ptr, std::ptr::null_mut());
ortsys![unsafe ReleaseEnv(env_ptr)];
ortsys![unsafe ReleaseEnv(self.env_ptr.as_ptr())];
}
}

/// Gets a reference to the global environment, creating one if an environment has not been
/// [`commit`](EnvironmentBuilder::commit)ted yet.
pub fn get_environment() -> Result<&'static Arc<Environment>> {
if let Some(c) = unsafe { &*G_ENV.cell.get() } {
Ok(c)
pub fn get_environment() -> Result<Arc<Environment>> {
let env = G_ENV.lock.read().expect("poisoned lock");
if let Some(env) = env.as_ref() {
Ok(Arc::clone(env))
} else {
debug!("Environment not yet initialized, creating a new one");
EnvironmentBuilder::new().commit()?;
// drop our read lock so we dont deadlock when `commit` takes a write lock
drop(env);

Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() })
debug!("Environment not yet initialized, creating a new one");
Ok(EnvironmentBuilder::new().commit()?)
}
}

Expand Down Expand Up @@ -151,12 +146,7 @@ impl EnvironmentBuilder {
}

/// Commit the environment configuration and set the global environment.
pub fn commit(self) -> Result<()> {
// drop global reference to previous environment
if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } {
drop(env_arc);
}

pub fn commit(self) -> Result<Arc<Environment>> {
let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
Expand Down Expand Up @@ -218,15 +208,20 @@ impl EnvironmentBuilder {
ortsys![unsafe DisableTelemetryEvents(env_ptr) -> Error::CreateEnvironment];
}

unsafe {
*G_ENV.cell.get() = Some(Arc::new(Environment {
execution_providers: self.execution_providers,
env_ptr: AtomicPtr::new(env_ptr),
has_global_threadpool
}));
};

Ok(())
let mut env_lock = G_ENV.lock.write().expect("poisoned lock");
// drop global reference to previous environment
if let Some(env_arc) = env_lock.take() {
drop(env_arc);
}
let env = Arc::new(Environment {
execution_providers: self.execution_providers,
// we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call
env_ptr: unsafe { NonNull::new_unchecked(env_ptr) },
has_global_threadpool
});
env_lock.replace(Arc::clone(&env));

Ok(env)
}
}

Expand Down Expand Up @@ -316,16 +311,11 @@ mod tests {
use super::*;

fn is_env_initialized() -> bool {
unsafe { (*G_ENV.cell.get()).as_ref() }.is_some()
&& !unsafe { (*G_ENV.cell.get()).as_ref() }
.unwrap_or_else(|| unreachable!())
.env_ptr
.load(Ordering::Relaxed)
.is_null()
G_ENV.lock.read().expect("poisoned lock").is_some()
}

fn env_ptr() -> Option<*mut ort_sys::OrtEnv> {
unsafe { (*G_ENV.cell.get()).as_ref() }.map(|f| f.env_ptr.load(Ordering::Relaxed))
(*G_ENV.lock.read().expect("poisoned lock")).as_ref().map(|f| f.env_ptr.as_ptr())
}

struct ConcurrentTestRun {
Expand Down
14 changes: 5 additions & 9 deletions src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
path::Path,
ptr::{self, NonNull},
rc::Rc,
sync::{atomic::Ordering, Arc}
sync::Arc
};

use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner};
Expand Down Expand Up @@ -312,10 +312,8 @@ impl SessionBuilder {
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
}

let env_ptr = env.env_ptr.load(Ordering::Relaxed);

let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];

let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };

Expand Down Expand Up @@ -348,7 +346,7 @@ impl SessionBuilder {
session_ptr,
allocator,
_extras: extras,
_environment: Arc::clone(env)
_environment: env
}),
inputs,
outputs
Expand Down Expand Up @@ -389,12 +387,10 @@ impl SessionBuilder {
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
}

let env_ptr = env.env_ptr.load(Ordering::Relaxed);

let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>();
let model_data_length = model_bytes.len();
ortsys![
unsafe CreateSessionFromArray(env_ptr, model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession;
unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession;
nonNull(session_ptr)
];

Expand Down Expand Up @@ -429,7 +425,7 @@ impl SessionBuilder {
session_ptr,
allocator,
_extras: extras,
_environment: Arc::clone(env)
_environment: env
}),
inputs,
outputs
Expand Down

0 comments on commit edcb219

Please sign in to comment.