diff --git a/Plugins/Onnx/include/Acts/Plugins/Onnx/OnnxRuntimeBase.hpp b/Plugins/Onnx/include/Acts/Plugins/Onnx/OnnxRuntimeBase.hpp
index 0e57a930df3..8a61e97d680 100644
--- a/Plugins/Onnx/include/Acts/Plugins/Onnx/OnnxRuntimeBase.hpp
+++ b/Plugins/Onnx/include/Acts/Plugins/Onnx/OnnxRuntimeBase.hpp
@@ -7,13 +7,16 @@
 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
 #pragma once
-
 #include <vector>
 
+#include <Eigen/Dense>
 #include <core/session/onnxruntime_cxx_api.h>
 
 namespace Acts {
 
+using NetworkBatchInput =
+    Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+
 // General class that sets up the ONNX runtime framework for loading a ML model
 // and using it for inference.
 class OnnxRuntimeBase {
@@ -30,7 +33,6 @@ class OnnxRuntimeBase {
   /// @brief Default destructor
   ~OnnxRuntimeBase() = default;
 
- protected:
   /// @brief Run the ONNX inference function
   ///
   /// @param inputTensorValues The input feature values used for prediction
@@ -39,6 +41,14 @@ class OnnxRuntimeBase {
   std::vector<float> runONNXInference(
       std::vector<float>& inputTensorValues) const;
 
+  /// @brief Run the ONNX inference function for a batch of input
+  ///
+  /// @param inputTensorValues Vector of the input feature values of all the inputs used for prediction
+  ///
+  /// @return The vector of output (predicted) values
+  std::vector<std::vector<float>> runONNXInference(
+      NetworkBatchInput& inputTensorValues) const;
+
  private:
   /// ONNX runtime session / model properties
   std::unique_ptr<Ort::Session> m_session;
diff --git a/Plugins/Onnx/src/OnnxRuntimeBase.cpp b/Plugins/Onnx/src/OnnxRuntimeBase.cpp
index 3f64eec2b01..8b2c6927338 100644
--- a/Plugins/Onnx/src/OnnxRuntimeBase.cpp
+++ b/Plugins/Onnx/src/OnnxRuntimeBase.cpp
@@ -11,102 +11,120 @@
 
 #include <cassert>
 #include <stdexcept>
 
-// parametrized constructor
+// Parametrized constructor
 Acts::OnnxRuntimeBase::OnnxRuntimeBase(Ort::Env& env, const char* modelPath) {
-  // set the ONNX runtime session options
+  // Set the ONNX runtime session options
   Ort::SessionOptions sessionOptions;
-  // set graph optimization level
+  // Set graph optimization level
   sessionOptions.SetGraphOptimizationLevel(
       GraphOptimizationLevel::ORT_ENABLE_BASIC);
-  // create the Ort session
+  // Create the Ort session
   m_session = std::make_unique<Ort::Session>(env, modelPath, sessionOptions);
-
-  // default allocator
+  // Default allocator
   Ort::AllocatorWithDefaultOptions allocator;
-  // get the names of the input nodes of the model
+  // Get the names of the input nodes of the model
   size_t numInputNodes = m_session->GetInputCount();
-
-  // iterate over all input nodes and get the name
+  // Iterate over all input nodes and get the name
   for (size_t i = 0; i < numInputNodes; i++) {
     m_inputNodeNamesAllocated.push_back(
         m_session->GetInputNameAllocated(i, allocator));
     m_inputNodeNames.push_back(m_inputNodeNamesAllocated.back().get());
 
-    // get the dimensions of the input nodes
-    // here we assume that all input nodes have the dimensions
+    // Get the dimensions of the input nodes,
+    // here we assume that all input nodes have the same dimensions
     Ort::TypeInfo inputTypeInfo = m_session->GetInputTypeInfo(i);
     auto tensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
     m_inputNodeDims = tensorInfo.GetShape();
-
-    // fix for symbolic dim = -1 from python
-    for (size_t j = 0; j < m_inputNodeDims.size(); j++) {
-      if (m_inputNodeDims[j] < 0) {
-        m_inputNodeDims[j] = 1;
-      }
-    }
   }
-  // get the names of the output nodes
+  // Get the names of the output nodes
   size_t numOutputNodes = m_session->GetOutputCount();
-
-  // iterate over all output nodes and get the name
+  // Iterate over all output nodes and get the name
   for (size_t i = 0; i < numOutputNodes; i++) {
     m_outputNodeNamesAllocated.push_back(
         m_session->GetOutputNameAllocated(i, allocator));
     m_outputNodeNames.push_back(m_outputNodeNamesAllocated.back().get());
 
-    // get the dimensions of the output nodes
+    // Get the dimensions of the output nodes
     // here we assume that all output nodes have the dimensions
     Ort::TypeInfo outputTypeInfo = m_session->GetOutputTypeInfo(i);
     auto tensorInfo = outputTypeInfo.GetTensorTypeAndShapeInfo();
     m_outputNodeDims = tensorInfo.GetShape();
-
-    // fix for symbolic dim = -1 from python
-    for (size_t j = 0; j < m_outputNodeDims.size(); j++) {
-      if (m_outputNodeDims[j] < 0) {
-        m_outputNodeDims[j] = 1;
-      }
-    }
   }
 }
 
-// inference function using ONNX runtime
-// the function assumes that the model has 1 input node and 1 output node
+// Inference function using ONNX runtime for one single entry
 std::vector<float> Acts::OnnxRuntimeBase::runONNXInference(
     std::vector<float>& inputTensorValues) const {
-  // create input tensor object from data values
+  Acts::NetworkBatchInput vectorInput(1, inputTensorValues.size());
+  for (size_t i = 0; i < inputTensorValues.size(); i++) {
+    vectorInput(0, i) = inputTensorValues[i];
+  }
+  auto vectorOutput = runONNXInference(vectorInput);
+  return vectorOutput[0];
+}
+
+// Inference function using ONNX runtime
+// the function assumes that the model has 1 input node and 1 output node
+std::vector<std::vector<float>> Acts::OnnxRuntimeBase::runONNXInference(
+    Acts::NetworkBatchInput& inputTensorValues) const {
+  int batchSize = inputTensorValues.rows();
+  std::vector<int64_t> inputNodeDims = m_inputNodeDims;
+  std::vector<int64_t> outputNodeDims = m_outputNodeDims;
+
+  // The first dim node should correspond to the batch size
+  // If it is -1, it is dynamic and should be set to the input size
+  if (inputNodeDims[0] == -1) {
+    inputNodeDims[0] = batchSize;
+  }
+  if (outputNodeDims[0] == -1) {
+    outputNodeDims[0] = batchSize;
+  }
+
+  if (batchSize != 1 &&
+      (inputNodeDims[0] != batchSize || outputNodeDims[0] != batchSize)) {
+    throw std::runtime_error(
+        "runONNXInference: batch size doesn't match the input or output node "
+        "size");
+  }
+
+  // Create input tensor object from data values
   // note: this assumes the model has only 1 input node
   Ort::MemoryInfo memoryInfo =
       Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
   Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
       memoryInfo, inputTensorValues.data(), inputTensorValues.size(),
-      m_inputNodeDims.data(), m_inputNodeDims.size());
-  // double-check that inputTensor is a Tensor
+      inputNodeDims.data(), inputNodeDims.size());
+  // Double-check that inputTensor is a Tensor
   if (!inputTensor.IsTensor()) {
     throw std::runtime_error(
         "runONNXInference: conversion of input to Tensor failed. ");
   }
-
-  // score model on input tensors, get back output tensors
+  // Score model on input tensors, get back output tensors
+  Ort::RunOptions run_options;
   std::vector<Ort::Value> outputTensors =
-      m_session->Run(Ort::RunOptions{nullptr}, m_inputNodeNames.data(),
-                     &inputTensor, m_inputNodeNames.size(),
-                     m_outputNodeNames.data(), m_outputNodeNames.size());
-  // double-check that outputTensors contains Tensors and that the count matches
+      m_session->Run(run_options, m_inputNodeNames.data(), &inputTensor,
+                     m_inputNodeNames.size(), m_outputNodeNames.data(),
+                     m_outputNodeNames.size());
+  // Double-check that outputTensors contains Tensors and that the count matches
   // that of output nodes
   if (!outputTensors[0].IsTensor() ||
       (outputTensors.size() != m_outputNodeNames.size())) {
     throw std::runtime_error(
        "runONNXInference: calculation of output failed. ");
"); } - - // get pointer to output tensor float values + // Get pointer to output tensor float values // note: this assumes the model has only 1 output node float* outputTensor = outputTensors.front().GetTensorMutableData(); - - // get the output values - std::vector outputTensorValues(m_outputNodeDims[1]); - for (size_t i = 0; i < outputTensorValues.size(); i++) { - outputTensorValues[i] = outputTensor[i]; + // Get the output values + std::vector> outputTensorValues( + batchSize, std::vector(outputNodeDims[1], -1)); + for (int i = 0; i < outputNodeDims[0]; i++) { + for (int j = 0; j < ((outputNodeDims.size() > 1) ? outputNodeDims[1] : 1); + j++) { + outputTensorValues[i][j] = outputTensor[i * outputNodeDims[1] + j]; + } } return outputTensorValues; }