From 6de6aa5b60b664584deb8f572d9d06301429923f Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Tue, 17 Sep 2024 16:44:58 -0500 Subject: [PATCH] feat: get overridable graph initializers --- src/lib.rs | 4 ++-- src/session/mod.rs | 48 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3d18d34..ed8b9a1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,8 +53,8 @@ pub use self::operator::{ InferShapeFn, Operator, OperatorDomain }; pub use self::session::{ - GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, RunOptions, - SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner + GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, OverridableInitializer, + RunOptions, SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner }; #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] diff --git a/src/session/mod.rs b/src/session/mod.rs index 7c7af56..6fcd9b7 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -1,6 +1,14 @@ //! Contains the [`Session`] and [`SessionBuilder`] types for managing ONNX Runtime sessions and performing inference. -use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc}; +use std::{ + any::Any, + ffi::{CStr, CString}, + marker::PhantomData, + ops::Deref, + os::raw::c_char, + ptr::NonNull, + sync::Arc +}; use crate::{ char_p_to_string, @@ -142,6 +150,28 @@ impl Session { Arc::clone(&self.inner) } + /// Returns a list of initializers which are overridable (i.e. also graph inputs). + #[must_use] + pub fn overridable_initializers(&self) -> Vec { + // can only fail if: + // - index is out of bounds (impossible because of the loop) + // - the model is not loaded (how could this even be possible?) + let mut size = 0; + ortsys![unsafe SessionGetOverridableInitializerCount(self.ptr(), &mut size).expect("infallible")]; + let allocator = Allocator::default(); + (0..size) + .map(|i| { + let mut name: *mut c_char = std::ptr::null_mut(); + ortsys![unsafe SessionGetOverridableInitializerName(self.ptr(), i, allocator.ptr.as_ptr(), &mut name).expect("infallible")]; + let name = unsafe { CStr::from_ptr(name) }.to_string_lossy().into_owned(); + let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); + ortsys![unsafe SessionGetOverridableInitializerTypeInfo(self.ptr(), i, &mut typeinfo_ptr).expect("infallible")]; + let dtype = ValueType::from_type_info(typeinfo_ptr); + OverridableInitializer { name, dtype } + }) + .collect() + } + /// Run input data through the ONNX graph, performing inference. /// /// See [`crate::inputs!`] for a convenient macro which will help you create your session inputs from `ndarray`s or @@ -455,6 +485,22 @@ unsafe impl Send for Session {} // temporary bug in ONNX Runtime or a wontfix. Maybe this impl should be removed just to be safe? unsafe impl Sync for Session {} +#[derive(Debug, Clone)] +pub struct OverridableInitializer { + name: String, + dtype: ValueType +} + +impl OverridableInitializer { + pub fn name(&self) -> &str { + &self.name + } + + pub fn dtype(&self) -> &ValueType { + &self.dtype + } +} + mod dangerous { use super::*;