From e1ca8ee6d4a188f01efd4083d6f4273a3ef901aa Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Sun, 16 Jul 2023 16:51:40 -0700 Subject: [PATCH] RunAsync C/CXX API (#16613) Implement RunAsync API - the session will run in a thread of intra-op thread pool. --------- Co-authored-by: Randy Shuai --- .../core/session/onnxruntime_c_api.h | 30 +++++ .../core/session/onnxruntime_cxx_api.h | 18 +++ .../core/session/onnxruntime_cxx_inline.h | 10 ++ onnxruntime/core/session/inference_session.cc | 110 ++++++++++++++++++ onnxruntime/core/session/inference_session.h | 14 +++ onnxruntime/core/session/onnxruntime_c_api.cc | 104 +++++++---------- onnxruntime/core/session/ort_apis.h | 7 ++ onnxruntime/test/shared_lib/test_inference.cc | 81 ++++++++++++- 8 files changed, 308 insertions(+), 66 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c87739b511de4..12e68630eaf48 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -696,6 +696,15 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api); +/** \brief Callback function for RunAsync + * + * \param[in] user_data User specific data that is passed back to the callback + * \param[out] outputs On success, outputs holds the inference results, on error, the value will be nullptr + * \param[out] num_outputs Number of outputs, on error, the value will be zero + * \param[out] status On error, status will provide details + */ +typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status); + /** \brief The C API * * All C API functions are defined inside this structure as pointers to functions. 
@@ -4316,6 +4325,27 @@ struct OrtApi { */ ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Run the model asynchronously in a thread owned by intra op thread pool + * + * \param[in] session + * \param[in] run_options If nullptr, will use a default ::OrtRunOptions + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] input Array of ::OrtValue%s of the input values + * \param[in] input_len Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[in] output_names_len Number of elements in the output_names and outputs array + * \param[out] output Array of OrtValue* owned by customers, sized to output_names_len. 
It could simply be an array of nullptr + * The array will be passed back to run_async_callback + * \param[in] run_async_callback Callback function on model run completion + * \param[in] user_data User data that is passed back to run_async_callback + */ + ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output, + _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index dbd4ab0012ceb..b629e1411ecd1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1067,6 +1067,24 @@ struct SessionImpl : ConstSessionImpl { void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding + /** \brief Run the model asynchronously in a thread owned by intra op thread pool + * + * Wraps OrtApi::RunAsync + * + * \param[in] run_options + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] input_values Array of ::OrtValue%s of the input values + * \param[in] input_count Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[out] output_values Array of ::OrtValue%s owned by customers, sized to output_count. 
It could simply be an array of nullptr + * The array will be passed back to the callback + * \param[in] output_count Number of elements in the output_names and outputs array + * \param[in] callback Callback function on model run completion + * \param[in] user_data User data that is passed back to the callback + */ + void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data); + /** \brief End profiling and return a copy of the profiling file name. * * \param allocator to allocate memory for the copy of the string returned diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 7000e4823eee2..22172832cde8e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -972,6 +972,16 @@ inline void SessionImpl::Run(const RunOptions& run_options, const IoBinding& ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding)); } +template +inline void SessionImpl::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) { + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, + ort_input_values, input_count, output_names, output_count, + ort_output_values, callback, user_data)); +} + template inline AllocatedStringPtr SessionImpl::EndProfilingAllocated(OrtAllocator* allocator) { char* out = nullptr; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 
cf0fd41ab8f01..f88bf8cf2468a 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2300,6 +2300,116 @@ Status InferenceSession::Run(const RunOptions& run_options, return retval; } +Status InferenceSession::Run(const RunOptions& run_options, + gsl::span feed_names, + gsl::span feeds, + gsl::span fetch_names, + gsl::span fetches) { + size_t num_feeds = feed_names.size(); + size_t num_fetches = fetch_names.size(); + InlinedVector feed_name_vec; + feed_name_vec.reserve(num_feeds); + InlinedVector feed_vec; + feed_vec.reserve(num_feeds); + + for (size_t i = 0; i != num_feeds; ++i) { + if (feed_names[i] == nullptr || feed_names[i][0] == '\0') { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty"); + } + + if (!feeds[i]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", feed_names[i]).c_str()); + } + + feed_name_vec.emplace_back(feed_names[i]); + feed_vec.emplace_back(*feeds[i]); + } + + // Create output feed + InlinedVector fetch_name_vec; + fetch_name_vec.reserve(num_fetches); + for (size_t i = 0; i != num_fetches; ++i) { + if (fetch_names[i] == nullptr || fetch_names[i][0] == '\0') { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty"); + } + fetch_name_vec.emplace_back(fetch_names[i]); + } + + std::vector fetch_vec; + fetch_vec.reserve(num_fetches); + for (size_t i = 0; i != num_fetches; ++i) { + if (fetches[i] != nullptr) { + fetch_vec.emplace_back(*fetches[i]); + } else { + fetch_vec.emplace_back(); + } + } + + Status status; + status = Run(run_options, feed_name_vec, feed_vec, fetch_name_vec, &fetch_vec, nullptr); + + if (!status.IsOK()) + return status; + + // We do it in two loops to make sure copy __ctors does not throw + InlinedVector> fetch_unique_ptrs; + fetch_unique_ptrs.reserve(num_fetches); + for (size_t i = 0; i != num_fetches; ++i) { + if (fetches[i] == nullptr) { + 
fetch_unique_ptrs.emplace_back(std::make_unique(fetch_vec[i])); + } else { + fetch_unique_ptrs.emplace_back(); + } + } + + for (size_t i = 0; i != num_fetches; ++i) { + if (fetches[i] == nullptr) { + ORT_ENFORCE(fetch_unique_ptrs[i] != nullptr); + fetches[i] = fetch_unique_ptrs[i].release(); + } + } + return Status::OK(); +} + +common::Status InferenceSession::RunAsync(const RunOptions* run_options, + gsl::span feed_names, + gsl::span feeds, + gsl::span fetch_names, + gsl::span fetches, + RunAsyncCallbackFn callback, + void* user_data) { + size_t num_fetches = fetch_names.size(); + if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync"); + } + std::function run_fn = [=]() { + ORT_TRY { + Status status; + if (run_options) { + status = Run(*run_options, feed_names, feeds, fetch_names, fetches); + } else { + RunOptions default_run_options; + status = Run(default_run_options, feed_names, feeds, fetch_names, fetches); + } + if (status.IsOK()) { + callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status)); + } else { + callback(user_data, {}, 0, ToOrtStatus(status)); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([=]() { + callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()))); + }); + } + ORT_CATCH(...) 
{ + callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception"))); + } + }; // run_fn + concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn); + return Status::OK(); +} + common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span output_names, std::vector* p_fetches) { return Run(RunOptions(), feeds, output_names, p_fetches); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index fd830b3c8d7f9..e4127085b3184 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -305,6 +305,20 @@ class InferenceSession { std::vector* p_fetches, const std::vector* p_fetches_device_info = nullptr); + [[nodiscard]] common::Status Run(const RunOptions& run_options, + gsl::span feed_names, + gsl::span feeds, + gsl::span fetch_names, + gsl::span fetches); + + [[nodiscard]] common::Status RunAsync(const RunOptions* run_options, + gsl::span feed_names, + gsl::span feeds, + gsl::span fetch_names, + gsl::span fetches, + RunAsyncCallbackFn callback, + void* user_data = nullptr); + /** * Run a pre-loaded and pre-intialized model. * Multiple threads are allowed to run this function; hence its thread-safe. 
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 94fbf6c1deb22..e6e17390ddc4b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -817,81 +817,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, _In_reads_(input_len) const char* const* input_names, _In_reads_(input_len) const OrtValue* const* input, size_t input_len, - _In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, _Inout_updates_all_(output_names_len) OrtValue** output) { API_IMPL_BEGIN auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess); - InlinedVector feed_names; - feed_names.reserve(input_len); - InlinedVector feeds; - feeds.reserve(input_len); - - for (size_t i = 0; i != input_len; ++i) { - if (input_names[i] == nullptr || input_names[i][0] == '\0') { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty"); - } - - if (!input[i]) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - MakeString("NULL input supplied for input ", input_names[i]).c_str()); - } - - feed_names.emplace_back(input_names[i]); - feeds.emplace_back(*input[i]); - } - - // Create output feed - InlinedVector output_names; - output_names.reserve(output_names_len); - for (size_t i = 0; i != output_names_len; ++i) { - if (output_names1[i] == nullptr || output_names1[i][0] == '\0') { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty"); - } - output_names.emplace_back(output_names1[i]); - } - - std::vector fetches; - fetches.reserve(output_names_len); - for (size_t i = 0; i != output_names_len; ++i) { - if (output[i] != nullptr) { - fetches.emplace_back(*output[i]); - } else { 
- fetches.emplace_back(); - } - } + gsl::span input_names_span(input_names, input_len); + gsl::span input_span(input, input_len); + gsl::span output_name_span(output_names, output_names_len); + gsl::span output_span(output, output_names_len); Status status; - if (run_options == nullptr) { - OrtRunOptions op; - status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr); + if (run_options) { + status = session->Run(*run_options, + input_names_span, + input_span, + output_name_span, + output_span); } else { - status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr); + const RunOptions default_run_options; + status = session->Run(default_run_options, + input_names_span, + input_span, + output_name_span, + output_span); } + return ToOrtStatus(status); + API_IMPL_END +} - if (!status.IsOK()) - return ToOrtStatus(status); - - // We do it in two loops to make sure copy __ctors does not throw - InlinedVector> output_unique_ptrs; - output_unique_ptrs.reserve(output_names_len); - for (size_t i = 0; i != output_names_len; ++i) { - if (output[i] == nullptr) { - output_unique_ptrs.emplace_back(std::make_unique(fetches[i])); - } else { - output_unique_ptrs.emplace_back(); - } - } +ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output, + _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) { + API_IMPL_BEGIN + auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess); - assert(output_unique_ptrs.size() == output_names_len); + gsl::span input_names_span(input_names, input_len); + gsl::span input_span(input, input_len); + gsl::span output_name_span(output_names, output_names_len); 
+ gsl::span output_span(output, output_names_len); - for (size_t i = 0; i != output_names_len; ++i) { - if (output[i] == nullptr) { - assert(output_unique_ptrs[i] != nullptr); - output[i] = output_unique_ptrs[i].release(); - } - } - return nullptr; + return ToOrtStatus(session->RunAsync(run_options, + input_names_span, + input_span, + output_name_span, + output_span, + run_async_callback, + user_data)); API_IMPL_END } @@ -2735,6 +2710,7 @@ static constexpr OrtApi ort_api_1_to_16 = { &OrtApis::GetROCMProviderOptionsAsString, &OrtApis::ReleaseROCMProviderOptions, &OrtApis::CreateAndRegisterAllocatorV2, + &OrtApis::RunAsync, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 2ea04f6acf367..6a18581bf6dad 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -478,4 +478,11 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + +ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** outputs, + _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data); } // namespace OrtApis diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 
8eb1152f1f612..f0457facd8d85 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3267,8 +3267,8 @@ TEST(MultiKernelSingleSchemaTest, valid) { Ort::Value::CreateTensor(memory_info, x_value, 10, x_dim, 1), }; - Ort::RunOptions run_optoins; - auto output_tensors = session.Run(run_optoins, input_names, input_tensors, 1, output_names, 2); + Ort::RunOptions run_options; + auto output_tensors = session.Run(run_options, input_names, input_tensors, 1, output_names, 2); ASSERT_TRUE(*output_tensors[1].GetTensorData() == 72); } @@ -3346,3 +3346,80 @@ TEST(MultiKernelSingleSchemaTest, DuplicateKernel) { } #endif + +static std::thread::id caller_tid = std::this_thread::get_id(); +static std::atomic_bool atomic_wait{false}; + +void CallbackSucceed(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status_ptr) { + auto callee_tid = std::this_thread::get_id(); + EXPECT_NE(*(reinterpret_cast(user_data)), callee_tid); + Ort::Status status(status_ptr); + EXPECT_TRUE(status.IsOK()); + EXPECT_EQ(num_outputs, 1UL); + Ort::Value output_value(outputs[0]); + EXPECT_EQ(output_value.At({1, 0}), 9.f); + output_value.release(); + atomic_wait.store(true); +} + +TEST(CApiTest, RunAsync) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(2); + Ort::Session session(*ort_env, MODEL_URI, session_options); + + const char* input_names[] = {"X"}; + float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + int64_t x_dim[] = {3, 2}; + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + Ort::Value input_tensors[1] = { + Ort::Value::CreateTensor(memory_info, x_value, 6, x_dim, 2), + }; + + const char* output_names[] = {"Y"}; + Ort::RunOptions run_options; + Ort::Value output_values[1] = {Ort::Value{nullptr}}; + + EXPECT_NO_THROW(session.RunAsync(run_options, + input_names, + input_tensors, + 1, + output_names, + output_values, + 1, + CallbackSucceed, + 
&caller_tid)); + + std::chrono::duration dur{100}; + // timeout in about 10 secs + for (int i = 0; i < 100 && !atomic_wait.load(); ++i) { + std::this_thread::sleep_for(dur); + } + + EXPECT_EQ(atomic_wait.load(), true); +} + +void CallbackFail(void*, OrtValue**, size_t, OrtStatusPtr) { + EXPECT_TRUE(false); // the callback is not supposed to be invoked +} + +TEST(CApiTest, RunAsyncFail) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); // This will cause RunAsync fail + Ort::Session session(*ort_env, MODEL_URI, session_options); + + const char* input_names[] = {"X"}; + float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + int64_t x_dim[] = {3, 2}; + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + Ort::Value input_tensors[1] = { + Ort::Value::CreateTensor(memory_info, x_value, 6, x_dim, 2), + }; + Ort::Value output_values[1] = {Ort::Value{nullptr}}; + const char* output_names[] = {"Y"}; + + Ort::RunOptions run_options; + EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception); +} \ No newline at end of file