Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RunAsync C/CXX API #16613

Merged
merged 30 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,18 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha

typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);

/** \brief Callback function for RunAsync
*
* \param[in] user_data A customized handle passed in by RunAsync.
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
* \param[out] outputs On success, holds the inference results.
NOTE:
1. Ort is in charge of the lifetime of the "outputs" array, but NOT of its elements, each of which is an OrtValue*.
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
2. The customer is expected to release each element of "outputs", each of which is an OrtValue*.
* \param[out] num_outputs Number of outputs.
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
* \param[out] status On error, status will provide details.
*/
typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

/** \brief The C API
*
* All C API functions are defined inside this structure as pointers to functions.
Expand Down Expand Up @@ -4316,6 +4328,28 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

/** \brief Run the model asynchronously in a thread owned by the intra-op thread pool
*
* \param[in] session
* \param[in] run_options If nullptr, will use a default ::OrtRunOptions
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] inputs Array of ::OrtValue%s of the input values
* \param[in] input_len Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names and outputs array
* \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be
* an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers
* to them will be set into the `outputs` array.
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
* \param[in] run_async_callback Callback function invoked on model run completion
* \param[in] user_data User data that is passed back to run_async_callback
*/
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* inputs, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** outputs,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
};

/*
Expand Down
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <unordered_map>
#include <utility>
#include <type_traits>
#include <functional>
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
Expand Down Expand Up @@ -781,6 +782,13 @@ struct SessionImpl : ConstSessionImpl<T> {

void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding

/** \brief Run the model asynchronously in a separate thread. Wraps OrtApi::RunAsync
 *
 * The call schedules the run and returns; `callback` is invoked on run completion,
 * with the output values as arguments on success, or with a status describing the error on failure.
 * A failure reported by OrtApi::RunAsync itself (i.e. the run could not be scheduled) goes through ThrowOnError.
 *
 * The name and value arrays are handed to the C API as raw pointers and are read when the run
 * actually executes, so they must stay valid until `callback` has been invoked.
 *
 * \param[in] run_options Options applied to this run
 * \param[in] input_names Array of null terminated strings of the input names, `input_count` entries
 * \param[in] input_values Array of input Values, `input_count` entries
 * \param[in] input_count Number of elements in `input_names` and `input_values`
 * \param[in] output_names Array of null terminated strings of the output names, `output_count` entries
 * \param[in,out] output_values Array of Values that receive the outputs, `output_count` entries
 * \param[in] output_count Number of elements in `output_names` and `output_values`
 * \param[in] callback Function invoked on run completion
 * \param[in] user_data Opaque pointer passed back to `callback` unchanged
 */
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);

RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
/** \brief End profiling and return a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
Expand Down
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,14 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
}

/// Wraps OrtApi::RunAsync.
///
/// Reinterprets the Value arrays as arrays of OrtValue*, forwards every argument to the C API and
/// passes the returned status to ThrowOnError. The actual run happens later on another thread, so
/// the arrays passed in must stay valid until `callback` has been invoked.
template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
                                     const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
  // The casts below are only valid while a Value is nothing more than a single OrtValue*;
  // fail the build instead of silently corrupting memory if that ever changes.
  static_assert(sizeof(Value) == sizeof(OrtValue*),
                "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
  ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, ort_input_values, input_count,
                                 output_names, output_count, ort_output_values, callback, user_data));
}

template <typename T>
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
char* out = nullptr;
Expand Down
107 changes: 107 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,113 @@ Status InferenceSession::Run(const RunOptions& run_options,
return retval;
}

// Runs the model from C API style arguments (parallel arrays of names and OrtValue pointers).
//
// - Every input/output name must be a non-null, non-empty string and every input value must be
//   non-null; otherwise INVALID_ARGUMENT is returned before anything is executed.
// - A null `run_options` means a default-constructed OrtRunOptions is used.
// - On success, each `output[i]` that was nullptr on entry is set to a newly allocated OrtValue.
//   Entries that were non-null on entry are passed to the run as pre-populated fetches and the
//   pointer itself is left untouched.
Status InferenceSession::Run(const OrtRunOptions* run_options,
                             const char* const* input_names,
                             const OrtValue* const* input, size_t input_len,
                             const char* const* output_names, size_t output_names_len,
                             OrtValue** output) {
  // Validate the inputs and collect them as (name, value) feeds.
  InlinedVector<std::string> feed_names;
  InlinedVector<OrtValue> feeds;
  feed_names.reserve(input_len);
  feeds.reserve(input_len);

  for (size_t in_idx = 0; in_idx < input_len; ++in_idx) {
    const char* const feed_name = input_names[in_idx];
    if (feed_name == nullptr || *feed_name == '\0') {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
    }

    const OrtValue* const feed_value = input[in_idx];
    if (feed_value == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", feed_name).c_str());
    }

    feed_names.emplace_back(feed_name);
    feeds.emplace_back(*feed_value);
  }

  // Validate and collect the requested output names.
  InlinedVector<std::string> output_name_vec;
  output_name_vec.reserve(output_names_len);
  for (size_t out_idx = 0; out_idx < output_names_len; ++out_idx) {
    const char* const fetch_name = output_names[out_idx];
    if (fetch_name == nullptr || *fetch_name == '\0') {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
    }
    output_name_vec.emplace_back(fetch_name);
  }

  // Seed the fetches: reuse what the caller supplied, otherwise start from an empty OrtValue.
  std::vector<OrtValue> fetches;
  fetches.reserve(output_names_len);
  for (size_t out_idx = 0; out_idx < output_names_len; ++out_idx) {
    if (output[out_idx] == nullptr) {
      fetches.emplace_back();
    } else {
      fetches.emplace_back(*output[out_idx]);
    }
  }

  Status status;
  if (run_options != nullptr) {
    status = Run(*run_options, feed_names, feeds, output_name_vec, &fetches, nullptr);
  } else {
    OrtRunOptions op;
    status = Run(op, feed_names, feeds, output_name_vec, &fetches, nullptr);
  }

  if (!status.IsOK()) {
    return status;
  }

  // Two phases on purpose: first make every copy (which may throw) while the results are still
  // owned by unique_ptrs, and only then hand raw pointers over to the caller's array.
  InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
  output_unique_ptrs.reserve(output_names_len);
  for (size_t out_idx = 0; out_idx < output_names_len; ++out_idx) {
    if (output[out_idx] != nullptr) {
      output_unique_ptrs.emplace_back();
    } else {
      output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[out_idx]));
    }
  }

  ORT_ENFORCE(output_unique_ptrs.size() == output_names_len);

  for (size_t i = 0; i < output_names_len; ++i) {
    if (output[i] != nullptr) {
      continue;
    }
    ORT_ENFORCE(output_unique_ptrs[i] != nullptr);
    output[i] = output_unique_ptrs[i].release();
  }
  return Status::OK();
}

// Schedules an asynchronous run of the model on the intra-op thread pool and returns immediately.
//
// The scheduled task calls the C-style Run() overload with the given arguments and then invokes
// `callback` exactly once:
//   - on success: callback(user_data, outputs, output_names_len, nullptr)
//   - on failure: callback(user_data, nullptr, 0, <status describing the error>)
//
// All pointer arguments are captured by value (i.e. as raw pointers) and are dereferenced later on
// the pool thread, so the caller must keep them alive until `callback` has been invoked.
//
// Returns INVALID_ARGUMENT without scheduling anything if `callback` is null or if the session has
// no intra-op thread pool able to run the task (degree of parallelism below 2).
Status InferenceSession::RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
                                  const OrtValue* const* inputs, size_t input_len,
                                  const char* const* output_names, size_t output_names_len,
                                  OrtValue** outputs, RunAsyncCallbackFn callback, void* user_data) {
  // The callback is the only way a result is ever reported; without one the task would crash on
  // the pool thread when it tries to report completion.
  if (callback == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "callback cannot be null for RunAsync");
  }

  if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
  }

  InferenceSession* sess = this;
  std::function<void()> run_fn = [sess, run_options, input_names, inputs, input_len, output_names,
                                  output_names_len, outputs, callback, user_data]() {
    // Phase 1: run the model and fold any exception into a Status. The callback is deliberately
    // NOT called in here: if it were, a callback that throws would land in the handlers below and
    // be invoked a second time.
    Status status = Status::OK();
    ORT_TRY {
      status = sess->Run(run_options, input_names, inputs, input_len, output_names, output_names_len, outputs);
    }
    ORT_CATCH(const std::exception& e) {
      ORT_HANDLE_EXCEPTION([&]() {
        status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, e.what());
      });
    }
    ORT_CATCH(...) {
      status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception");
    }

    // Phase 2: report the outcome exactly once.
    ORT_TRY {
      if (status.IsOK()) {
        callback(user_data, outputs, output_names_len, nullptr);
      } else {
        callback(user_data, nullptr, 0, ToOrtStatus(status));
      }
    }
    ORT_CATCH(...) {
      // An exception escaping the user callback has nowhere to be reported (the callback itself is
      // the reporting channel) and must not unwind into the thread pool, so it is dropped here.
    }
  };
  concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
  return Status::OK();
}

common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) {
return Run(RunOptions(), feeds, output_names, p_fetches);
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ class InferenceSession {
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);

// Runs the model from C API style arguments: parallel arrays of names and OrtValue pointers.
// - run_options may be nullptr, in which case a default-constructed OrtRunOptions is used.
// - input_names/input hold input_len entries; output_names/output hold output_names_len entries.
// - A null or empty name, or a null input value, yields an INVALID_ARGUMENT status.
// - On success, every output[i] that was nullptr on entry is set to a newly allocated OrtValue.
[[nodiscard]] common::Status Run(const OrtRunOptions* run_options, const char* const* input_names,
const OrtValue* const* input, size_t input_len,
const char* const* output_names, size_t output_names_len,
OrtValue** output);

[[nodiscard]] common::Status RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
const OrtValue* const* inputs, size_t input_len,
const char* const* output_names, size_t output_names_len,
OrtValue** outputs, RunAsyncCallbackFn callback, void* user_data = nullptr);

/**
* Run a pre-loaded and pre-intialized model.
* Multiple threads are allowed to run this function; hence its thread-safe.
Expand Down
96 changes: 29 additions & 67 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,81 +817,41 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

InlinedVector<std::string> feed_names;
feed_names.reserve(input_len);
InlinedVector<OrtValue> feeds;
feeds.reserve(input_len);

for (size_t i = 0; i != input_len; ++i) {
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
}

if (!input[i]) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
MakeString("NULL input supplied for input ", input_names[i]).c_str());
}

feed_names.emplace_back(input_names[i]);
feeds.emplace_back(*input[i]);
}

// Create output feed
InlinedVector<std::string> output_names;
output_names.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
}
output_names.emplace_back(output_names1[i]);
}

std::vector<OrtValue> fetches;
fetches.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] != nullptr) {
fetches.emplace_back(*output[i]);
} else {
fetches.emplace_back();
}
}

Status status;
if (run_options == nullptr) {
OrtRunOptions op;
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
auto status = session->Run(run_options, input_names, input, input_len, output_names, output_names_len, output);
if (status.IsOK()) {
return nullptr;
} else {
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
}

if (!status.IsOK())
return ToOrtStatus(status);

// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
output_unique_ptrs.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
} else {
output_unique_ptrs.emplace_back();
}
}
API_IMPL_END
}

assert(output_unique_ptrs.size() == output_names_len);

for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
assert(output_unique_ptrs[i] != nullptr);
output[i] = output_unique_ptrs[i].release();
}
// C API entry point behind OrtApi::RunAsync.
// Casts the opaque OrtSession handle to the InferenceSession and forwards every argument to
// InferenceSession::RunAsync. Returns nullptr when the call succeeded, otherwise an OrtStatus
// built from the returned status.
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* inputs, size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
                    _Inout_updates_all_(output_names_len) OrtValue** outputs,
                    _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
  auto status = session->RunAsync(run_options,
                                  input_names,
                                  inputs,
                                  input_len,
                                  output_names,
                                  output_names_len,
                                  outputs,
                                  run_async_callback,
                                  user_data);
  if (status.IsOK()) {
    return nullptr;
  }
  return ToOrtStatus(status);
  API_IMPL_END
}

Expand Down Expand Up @@ -2735,6 +2695,8 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::GetROCMProviderOptionsAsString,
&OrtApis::ReleaseROCMProviderOptions,
&OrtApis::CreateAndRegisterAllocatorV2,

&OrtApis::RunAsync,
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,11 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions

ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

// Declaration of the function registered as OrtApi::RunAsync: starts an asynchronous run of the
// session and reports completion through run_async_callback, passing user_data back to it.
ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** outputs,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
} // namespace OrtApis
Loading