Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RunAsync C/CXX API #16613

Merged
merged 30 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha

/** \brief Function pointer type that takes session options and the API base and returns an ::OrtStatus*
 *
 * NOTE(review): presumably the registration entry point exported by a custom-op library - confirm against the callers.
 */
typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);

/** \brief Callback type accepted by OrtApi::RunAsync
 *
 * Invoked with an ::OrtValue pointer, a size_t and an ::OrtStatusPtr, in that order.
 * NOTE(review): the parameters are unnamed; presumably (outputs, number of outputs, status of the run).
 * Confirm the meaning and ownership of each argument and name them in the signature.
 */
typedef void (*RunAsyncCallbackFn)(OrtValue*, size_t, OrtStatusPtr);
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved

/** \brief The C API
*
* All C API functions are defined inside this structure as pointers to functions.
Expand Down Expand Up @@ -4316,6 +4318,12 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

/** \brief Run the model asynchronously
 *
 * Takes a session, optional run options and the input / output-name arrays; the outcome is reported
 * through \p run_async_callback instead of through an output array.
 *
 * \param[in] session
 * \param[in] run_options May be null (annotated _In_opt_)
 * \param[in] input_names Array of \p input_len input names
 * \param[in] inputs Array of \p input_len input ::OrtValue pointers
 * \param[in] input_len Number of elements in \p input_names and \p inputs
 * \param[in] output_names Array of \p output_names_len output names
 * \param[in] output_names_len Number of elements in \p output_names
 * \param[in] run_async_callback Callback of type ::RunAsyncCallbackFn
 *
 * NOTE(review): the lifetime required of the argument arrays and the thread on which the callback
 * runs are not stated in this header - document them once confirmed.
 */
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* inputs, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_In_ RunAsyncCallbackFn run_async_callback);
};

/*
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,9 @@ struct SessionImpl : ConstSessionImpl<T> {

void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding

/// Wraps OrtApi::RunAsync. The status returned by the C API is passed to ThrowOnError.
/// \param callback Function of type RunAsyncCallbackFn forwarded unchanged to the C API.
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count, RunAsyncCallbackFn callback);
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

/** \brief End profiling and return a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,13 @@ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, con
return output_values;
}

/// Wraps OrtApi::RunAsync: forwards the arguments to the C API and passes the returned status to ThrowOnError.
///
/// \param run_options   Run options wrapper; handed to the C API as its underlying OrtRunOptions handle.
/// \param input_names   Array of input_count input names.
/// \param input_values  Array of input_count input values.
/// \param input_count   Number of inputs.
/// \param output_names  Array of output_count output names.
/// \param output_count  Number of outputs.
/// \param callback      Completion callback forwarded unchanged to the C API.
template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
                                     const char* const* output_names, size_t output_count, RunAsyncCallbackFn callback) {
  // Reinterprets the Value array as an array of raw OrtValue pointers; this relies on Value being
  // layout-compatible with a single OrtValue* (NOTE(review): confirm a static_assert guards this).
  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
  // Pass the wrapper itself rather than its address: `&run_options` is a pointer to the C++ wrapper
  // object, not a `const OrtRunOptions*`, and does not convert to one. The wrapper converts to its
  // underlying handle implicitly.
  ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, callback));
}

template <typename T>
inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count) {
Expand Down
95 changes: 95 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,101 @@ Status InferenceSession::Run(const RunOptions& run_options,
return retval;
}

// Runs the model from C-API style raw arrays.
//
// Validates the names and inputs, converts them into the containers used by the container-based
// Run overload, executes it, and on success hands newly allocated OrtValues back through `output`.
//
// run_options       optional; a default-constructed OrtRunOptions is used when null.
// input_names/input parallel arrays of input_len entries; a null or empty name, or a null input,
//                   yields INVALID_ARGUMENT.
// output_names1     array of output_names_len names; a null or empty name yields INVALID_ARGUMENT.
// output            array of output_names_len slots. A non-null slot is copied in as the initial
//                   fetch for that output and is left unchanged; a null slot is replaced, on
//                   success only, with a heap-allocated OrtValue released to the caller.
//
// Returns the status of the underlying Run call, or INVALID_ARGUMENT for rejected arguments.
Status InferenceSession::Run(const OrtRunOptions* run_options,
                             const char* const* input_names,
                             const OrtValue* const* input, size_t input_len,
                             const char* const* output_names1, size_t output_names_len,
                             OrtValue** output) {
  // Collect the feeds, rejecting unnamed or missing inputs.
  InlinedVector<std::string> input_name_list;
  InlinedVector<OrtValue> input_values;
  input_name_list.reserve(input_len);
  input_values.reserve(input_len);

  for (size_t idx = 0; idx < input_len; ++idx) {
    const bool name_missing = input_names[idx] == nullptr || input_names[idx][0] == '\0';
    if (name_missing) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
    }

    if (input[idx] == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", input_names[idx]).c_str());
    }

    input_name_list.emplace_back(input_names[idx]);
    input_values.emplace_back(*input[idx]);
  }

  // Collect the requested output names. All names are validated before `output` is touched.
  InlinedVector<std::string> requested_outputs;
  requested_outputs.reserve(output_names_len);
  for (size_t idx = 0; idx < output_names_len; ++idx) {
    const bool name_missing = output_names1[idx] == nullptr || output_names1[idx][0] == '\0';
    if (name_missing) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
    }
    requested_outputs.emplace_back(output_names1[idx]);
  }

  // Seed the fetches: caller-supplied values are copied in, empty slots get a default OrtValue.
  std::vector<OrtValue> fetched;
  fetched.reserve(output_names_len);
  for (size_t idx = 0; idx < output_names_len; ++idx) {
    if (output[idx] == nullptr) {
      fetched.emplace_back();
    } else {
      fetched.emplace_back(*output[idx]);
    }
  }

  Status run_status;
  if (run_options != nullptr) {
    run_status = Run(*run_options, input_name_list, input_values, requested_outputs, &fetched, nullptr);
  } else {
    OrtRunOptions default_options;
    run_status = Run(default_options, input_name_list, input_values, requested_outputs, &fetched, nullptr);
  }

  if (!run_status.IsOK()) {
    return run_status;
  }

  // Two passes on purpose: every new OrtValue is first allocated under unique_ptr ownership, so an
  // exception thrown by a later allocation or copy cannot leak the earlier ones. Only once all
  // allocations have succeeded is ownership released into the caller's array.
  InlinedVector<std::unique_ptr<OrtValue>> owned_results;
  owned_results.reserve(output_names_len);
  for (size_t idx = 0; idx < output_names_len; ++idx) {
    if (output[idx] != nullptr) {
      owned_results.emplace_back();
    } else {
      owned_results.emplace_back(std::make_unique<OrtValue>(fetched[idx]));
    }
  }

  ORT_ENFORCE(owned_results.size() == output_names_len);

  for (size_t idx = 0; idx < output_names_len; ++idx) {
    if (output[idx] != nullptr) {
      continue;
    }
    ORT_ENFORCE(owned_results[idx] != nullptr);
    output[idx] = owned_results[idx].release();
  }
  return Status::OK();
}

// Schedules a run of the model through ThreadPool::Schedule on inter_op_thread_pool_ and reports
// the outcome through `callback`.
//
// run_options, input_names, input, output_name
//                   raw pointers that are copied into the scheduled task and dereferenced when the
//                   task executes, so the caller must keep what they point at valid until the
//                   callback has been invoked.
// input_len         number of entries in input_names / input.
// output_names_len  number of entries in output_name; at most 1 is accepted (see below).
// callback          invoked from the scheduled task: with the produced OrtValue* on success, or
//                   with the failure converted by ToOrtStatus otherwise. Must not be null.
//
// Returns INVALID_ARGUMENT if the arguments are rejected before scheduling, otherwise OK. The
// returned status says nothing about the outcome of the run itself.
Status InferenceSession::RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
                                  const OrtValue* const* input, size_t input_len,
                                  const char* const* output_name, size_t output_names_len,
                                  RunAsyncCallbackFn callback) {
  if (callback == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "callback cannot be null");
  }

  // The task owns a single OrtValue* slot and the callback can carry only one OrtValue*. Run()
  // writes output_names_len entries into the output array, so more than one requested output
  // would write past that slot.
  if (output_names_len > 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "RunAsync supports at most one output");
  }

  // Capture everything by value: the task runs after this function has returned, so references to
  // the parameters (or to any local of this frame) would dangle by the time it executes.
  std::function<void()> run_fn = [this, run_options, input_names, input, input_len,
                                  output_name, output_names_len, callback]() {
    OrtValue* outputs{};
    auto status = Run(run_options, input_names, input, input_len, output_name, output_names_len, &outputs);
    // NOTE(review): the size_t passed to the callback is 0 on success and output_names_len on
    // failure, which looks inverted if it is meant to be an output count. Kept as-is because the
    // callback contract is not documented; confirm and fix together with the typedef.
    if (status.IsOK()) {
      callback(outputs, 0, {});
    } else {
      callback({}, output_names_len, ToOrtStatus(status));
    }
  };
  concurrency::ThreadPool::Schedule(inter_op_thread_pool_.get(), run_fn);
  return Status::OK();
}

common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) {
return Run(RunOptions(), feeds, output_names, p_fetches);
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ class InferenceSession {
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);

/**
 * Run the model from C-API style raw arrays.
 * A null or empty input/output name, or a null input value, is rejected with INVALID_ARGUMENT.
 * @param run_options may be null, in which case default-constructed run options are used.
 * @param input_names array of input_len input names.
 * @param input array of input_len input values.
 * @param output_names1 array of output_names_len output names.
 * @param output array of output_names_len slots. Slots that are null on entry receive a newly
 *        allocated OrtValue on success; non-null slots are left pointing at the caller's value.
 * @return the status of the underlying run, or INVALID_ARGUMENT as described above.
 */
[[nodiscard]] common::Status Run(const OrtRunOptions* run_options, const char* const* input_names,
const OrtValue* const* input, size_t input_len,
const char* const* output_names1, size_t output_names_len,
OrtValue** output);

/**
 * Schedule a run of the model on the inter-op thread pool and report its outcome through callback.
 * The arguments mirror the raw-array Run overload, with the output array replaced by a callback.
 * The pointer arguments are used by the scheduled task, so they must remain valid until the
 * callback has been invoked.
 * @param callback invoked by the scheduled task once the run has finished, with either the result
 *        or the failure status.
 * @return a status describing the scheduling step only; the outcome of the run itself is delivered
 *         to the callback.
 */
[[nodiscard]] common::Status RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
const OrtValue* const* input, size_t input_len,
const char* const* output_names1, size_t output_names_len,
RunAsyncCallbackFn callback);

/**
* Run a pre-loaded and pre-intialized model.
* Multiple threads are allowed to run this function; hence its thread-safe.
Expand Down
87 changes: 20 additions & 67 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,81 +817,32 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

InlinedVector<std::string> feed_names;
feed_names.reserve(input_len);
InlinedVector<OrtValue> feeds;
feeds.reserve(input_len);

for (size_t i = 0; i != input_len; ++i) {
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
}

if (!input[i]) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
MakeString("NULL input supplied for input ", input_names[i]).c_str());
}

feed_names.emplace_back(input_names[i]);
feeds.emplace_back(*input[i]);
}

// Create output feed
InlinedVector<std::string> output_names;
output_names.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
}
output_names.emplace_back(output_names1[i]);
}

std::vector<OrtValue> fetches;
fetches.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] != nullptr) {
fetches.emplace_back(*output[i]);
} else {
fetches.emplace_back();
}
}

Status status;
if (run_options == nullptr) {
OrtRunOptions op;
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
auto status = session->Run(run_options, input_names, input, input_len, output_names, output_names_len, output);
if (status.IsOK()) {
return nullptr;
} else {
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
}

if (!status.IsOK())
return ToOrtStatus(status);

// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
output_unique_ptrs.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
} else {
output_unique_ptrs.emplace_back();
}
}
API_IMPL_END
}

assert(output_unique_ptrs.size() == output_names_len);

for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
assert(output_unique_ptrs[i] != nullptr);
output[i] = output_unique_ptrs[i].release();
}
// C API entry point for OrtApi::RunAsync.
// Casts the opaque session handle to InferenceSession, forwards every argument to
// InferenceSession::RunAsync, and converts the returned Status: nullptr when it is OK, otherwise
// the OrtStatus produced by ToOrtStatus.
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* input, size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
                    _In_ RunAsyncCallbackFn run_async_callback) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
  const auto status = session->RunAsync(run_options, input_names, input, input_len, output_names, output_names_len, run_async_callback);
  // Single exit point: the previous if/else returned on both branches, leaving an unreachable
  // trailing `return nullptr;` behind it.
  return status.IsOK() ? nullptr : ToOrtStatus(status);
  API_IMPL_END
}

Expand Down Expand Up @@ -2735,6 +2686,8 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::GetROCMProviderOptionsAsString,
&OrtApis::ReleaseROCMProviderOptions,
&OrtApis::CreateAndRegisterAllocatorV2,

&OrtApis::RunAsync,
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,10 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions

// Declaration of the implementation behind OrtApi::CreateAndRegisterAllocatorV2.
ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

// Declaration of the implementation behind OrtApi::RunAsync.
// NOTE(review): the output-name parameter is called `output_names1` here but `output_names` in the
// definition and in the public OrtApi member; consider aligning the names.
ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
_In_ RunAsyncCallbackFn run_async_callback);
} // namespace OrtApis