Skip to content

Commit

Permalink
fix: enable compilation with ONNX 1.13 (#1835)
Browse files Browse the repository at this point in the history
Simple PR to fix a compilation error against ONNX runtime 1.13.1. Basically, the `Ort::Session::Get{Input,Output}Name` functions got removed starting from 1.13, and we have the use the `Ort::Session::Get{Input,Output}NameAllocated` versions which were introduced in 1.12. 

One implication is that this PR introduces a lower bound of 1.12.0 on the ONNX runtime version.
  • Loading branch information
gagnonlg authored Feb 10, 2023
1 parent 97a8458 commit c2ef676
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ clang_tidy:

build:
stage: build
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v30
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v34
tags:
- docker
variables:
Expand Down Expand Up @@ -106,7 +106,7 @@ test:
stage: test
needs:
- build
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v30
image: ghcr.io/acts-project/ubuntu2004_exatrkx:v34
tags:
- docker-gpu-nvidia
script:
Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ set(_acts_doxygen_version 1.9.4)
set(_acts_eigen3_version 3.3.7)
set(_acts_hepmc3_version 3.2.1)
set(_acts_nlohmanjson_version 3.2.0)
set(_acts_onnxruntime_version 1.12.0)
set(_acts_root_version 6.20)
set(_acts_tbb_version 2020.1)

Expand Down Expand Up @@ -288,7 +289,7 @@ if(ACTS_BUILD_PLUGIN_JSON)
endif()
endif()
if(ACTS_BUILD_PLUGIN_ONNX)
find_package(OnnxRuntime REQUIRED)
find_package(OnnxRuntime ${_acts_onnxruntime_version} REQUIRED)
endif()
if(ACTS_BUILD_PLUGIN_SYCL)
find_package(SYCL REQUIRED)
Expand Down
4 changes: 3 additions & 1 deletion Plugins/Onnx/include/Acts/Plugins/Onnx/OnnxRuntimeBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ class OnnxRuntimeBase {
private:
/// ONNX runtime session / model properties
std::unique_ptr<Ort::Session> m_session;
std::vector<Ort::AllocatedStringPtr> m_inputNodeNamesAllocated;
std::vector<const char*> m_inputNodeNames;
std::vector<int64_t> m_inputNodeDims;
std::vector<Ort::AllocatedStringPtr> m_outputNodeNamesAllocated;
std::vector<const char*> m_outputNodeNames;
std::vector<int64_t> m_outputNodeDims;
};

} // namespace Acts
} // namespace Acts
10 changes: 6 additions & 4 deletions Plugins/Onnx/src/OnnxRuntimeBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ Acts::OnnxRuntimeBase::OnnxRuntimeBase(Ort::Env& env, const char* modelPath) {

// get the names of the input nodes of the model
size_t numInputNodes = m_session->GetInputCount();
m_inputNodeNames.resize(numInputNodes);

// iterate over all input nodes and get the name
for (size_t i = 0; i < numInputNodes; i++) {
m_inputNodeNames[i] = m_session->GetInputName(i, allocator);
m_inputNodeNamesAllocated.push_back(
m_session->GetInputNameAllocated(i, allocator));
m_inputNodeNames.push_back(m_inputNodeNamesAllocated.back().get());

// get the dimensions of the input nodes
// here we assume that all input nodes have the dimensions
Expand All @@ -47,11 +48,12 @@ Acts::OnnxRuntimeBase::OnnxRuntimeBase(Ort::Env& env, const char* modelPath) {

// get the names of the output nodes
size_t numOutputNodes = m_session->GetOutputCount();
m_outputNodeNames.resize(numOutputNodes);

// iterate over all output nodes and get the name
for (size_t i = 0; i < numOutputNodes; i++) {
m_outputNodeNames[i] = m_session->GetOutputName(i, allocator);
m_outputNodeNamesAllocated.push_back(
m_session->GetOutputNameAllocated(i, allocator));
m_outputNodeNames.push_back(m_outputNodeNamesAllocated.back().get());

// get the dimensions of the output nodes
// here we assume that all output nodes have the dimensions
Expand Down
12 changes: 11 additions & 1 deletion cmake/FindOnnxRuntime.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,17 @@ find_path(
if(NOT OnnxRuntime_INCLUDE_DIR)
message(FATAL_ERROR "onnxruntime includes not found")
else()
message(STATUS "Found OnnxRuntime includes at ${OnnxRuntime_INCLUDE_DIR}")
file(READ ${OnnxRuntime_INCLUDE_DIR}/core/session/onnxruntime_c_api.h ver)
string(REGEX MATCH "ORT_API_VERSION ([0-9]*)" _ ${ver})
set(OnnxRuntime_API_VERSION ${CMAKE_MATCH_1})
message(STATUS "Found OnnxRuntime includes at ${OnnxRuntime_INCLUDE_DIR} (API version: ${OnnxRuntime_API_VERSION})")
endif()


string(REPLACE "." ";" OnnxRuntime_MIN_VERSION_LIST ${_acts_onnxruntime_version})
list(GET OnnxRuntime_MIN_VERSION_LIST 1 OnnxRuntime_MIN_API_VERSION)
if("${OnnxRuntime_API_VERSION}" LESS ${OnnxRuntime_MIN_API_VERSION})
message(FATAL_ERROR "OnnxRuntime API version ${OnnxRuntime_MIN_API_VERSION} or greater required")
endif()

include(FindPackageHandleStandardArgs)
Expand Down
2 changes: 1 addition & 1 deletion docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ components:
- [Geant4](http://geant4.org/) for some examples
- [HepMC](https://gitlab.cern.ch/hepmc/HepMC3) >= 3.2.1 for some examples
- [Intel Threading Building Blocks](https://01.org/tbb) >= 2020.1 for the examples
- [ONNX Runtime](https://onnxruntime.ai/) for the ONNX plugin, the Exa.TrkX plugin and some examples
- [ONNX Runtime](https://onnxruntime.ai/) >= 1.12.0 for the ONNX plugin, the Exa.TrkX plugin and some examples
- [Pythia8](https://pythia.org) for some examples
- [ROOT](https://root.cern.ch) >= 6.20 for the TGeo plugin and the examples
- [Sphinx](https://www.sphinx-doc.org) >= 2.0 with [Breathe](https://breathe.readthedocs.io/en/latest/), [Exhale](https://exhale.readthedocs.io/en/latest/), and [recommonmark](https://recommonmark.readthedocs.io/en/latest/index.html) extensions for the documentation
Expand Down

0 comments on commit c2ef676

Please sign in to comment.