Skip to content

Commit

Permalink
feat: get overridable graph initializers
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Sep 17, 2024
1 parent b58595c commit 6de6aa5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ pub use self::operator::{
InferShapeFn, Operator, OperatorDomain
};
pub use self::session::{
GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, RunOptions,
SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner
GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, OverridableInitializer,
RunOptions, SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner
};
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
Expand Down
48 changes: 47 additions & 1 deletion src/session/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
//! Contains the [`Session`] and [`SessionBuilder`] types for managing ONNX Runtime sessions and performing inference.
use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc};
use std::{
any::Any,
ffi::{CStr, CString},
marker::PhantomData,
ops::Deref,
os::raw::c_char,
ptr::NonNull,
sync::Arc
};

use crate::{
char_p_to_string,
Expand Down Expand Up @@ -142,6 +150,28 @@ impl Session {
Arc::clone(&self.inner)
}

/// Returns a list of initializers which are overridable (i.e. also graph inputs).
#[must_use]
pub fn overridable_initializers(&self) -> Vec<OverridableInitializer> {
// can only fail if:
// - index is out of bounds (impossible because of the loop)
// - the model is not loaded (how could this even be possible?)
let mut size = 0;
ortsys![unsafe SessionGetOverridableInitializerCount(self.ptr(), &mut size).expect("infallible")];
let allocator = Allocator::default();
(0..size)
.map(|i| {
let mut name: *mut c_char = std::ptr::null_mut();
ortsys![unsafe SessionGetOverridableInitializerName(self.ptr(), i, allocator.ptr.as_ptr(), &mut name).expect("infallible")];
let name = unsafe { CStr::from_ptr(name) }.to_string_lossy().into_owned();
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
ortsys![unsafe SessionGetOverridableInitializerTypeInfo(self.ptr(), i, &mut typeinfo_ptr).expect("infallible")];
let dtype = ValueType::from_type_info(typeinfo_ptr);
OverridableInitializer { name, dtype }
})
.collect()
}

/// Run input data through the ONNX graph, performing inference.
///
/// See [`crate::inputs!`] for a convenient macro which will help you create your session inputs from `ndarray`s or
Expand Down Expand Up @@ -455,6 +485,22 @@ unsafe impl Send for Session {}
// temporary bug in ONNX Runtime or a wontfix. Maybe this impl should be removed just to be safe?
unsafe impl Sync for Session {}

#[derive(Debug, Clone)]
pub struct OverridableInitializer {
name: String,
dtype: ValueType
}

impl OverridableInitializer {
pub fn name(&self) -> &str {
&self.name
}

pub fn dtype(&self) -> &ValueType {
&self.dtype
}
}

mod dangerous {
use super::*;

Expand Down

0 comments on commit 6de6aa5

Please sign in to comment.