diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 5e3002aec424d..c6dc2bf790bb2 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -136,12 +136,20 @@ class TunableOp { ITuningContext* ctx = params->TuningContext(); if (ctx->IsTunableOpEnabled()) { auto& mgr = ctx->GetTuningResultsManager(); - id = mgr.Lookup(Signature(), params->Signature()); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + id = mgr.Lookup(op_sig, params_sig); + if (id >= static_cast<int>(ops_.size())) { + LOGS_DEFAULT(FATAL) << "Invalid TunableOp kernel id for " << op_sig + << ", id:" << id << ", registered op:" << ops_.size(); + mgr.Delete(op_sig, params_sig); + id = -1; + } if (id < 0) { auto maybe_proxy_params = PreTuning(params); id = FindFastest(maybe_proxy_params); PostTuning(maybe_proxy_params); - mgr.Add(Signature(), params->Signature(), id); + mgr.Add(op_sig, params_sig, id); } } ORT_RETURN_IF_ERROR(ops_[id](params)); diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index d4aa2d0b412b7..ff93f0a224c95 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -17,6 +17,7 @@ class TuningResultsValidator; class ITuningContext { public: + explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {} virtual ~ITuningContext() = default; virtual void EnableTunableOp() = 0; @@ -27,6 +28,12 @@ class ITuningContext { virtual const TuningResultsManager& GetTuningResultsManager() const = 0; virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0; + + virtual TuningResults SaveTuningResults() const; + virtual Status LoadTuningResults(const TuningResults& tr); + + protected: + IExecutionProvider* ep_; }; class TuningResultsManager { @@ -38,6 +45,7 @@ class TuningResultsManager { int Lookup(const std::string& op_signature, const std::string& params_signature) const; void 
Add(const std::string& op_signature, const std::string& params_signature, int best_id); + void Delete(const std::string& op_signature, const std::string& params_signature); void Load(const std::unordered_map& results_to_load); std::unordered_map Dump() const; diff --git a/onnxruntime/core/framework/tuning_context_impl.h b/onnxruntime/core/framework/tuning_context_impl.h index 7fc124adddb9f..8cc3bfe6af316 100644 --- a/onnxruntime/core/framework/tuning_context_impl.h +++ b/onnxruntime/core/framework/tuning_context_impl.h @@ -21,6 +21,22 @@ namespace onnxruntime { +TuningResults ITuningContext::SaveTuningResults() const { + TuningResults tr; + tr.ep = ep_->Type(); + tr.validators = GetTuningResultsValidator().GetAllValidators(); + tr.results = GetTuningResultsManager().Dump(); + return tr; +} + +Status ITuningContext::LoadTuningResults(const TuningResults& tr) { + ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch"); + LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep; + ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators)); + GetTuningResultsManager().Load(tr.results); + return Status::OK(); +} + KernelMap TuningResultsManager::Lookup(const std::string& op_signature) const { std::scoped_lock l{lock_}; auto it = results_.find(op_signature); @@ -74,6 +90,22 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin AddImpl(op_signature, params_signature, best_id, it->second); } +// NOLINTNEXTLINE(bugprone-easily-swappable-parameters) +void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + return; + } + + auto it2 = it->second.find(params_signature); + if (it2 == it->second.end()) { + return; + } + it->second.erase(it2); +} + std::unordered_map TuningResultsManager::Dump() const { std::scoped_lock l{lock_}; return results_; @@ -95,6 +127,11 @@ void MergeImpl( 
} void TuningResultsManager::Load(const std::unordered_map& results_to_load) { + for(const auto& [op_sig, kernel_map]: results_to_load) { + for(const auto& [param_sig, kernel_id] : kernel_map) { + LOGS_DEFAULT(VERBOSE) << op_sig << " " << param_sig << " " << kernel_id; + } + } std::scoped_lock l{lock_}; for (const auto& [op_signature, kernel_map] : results_to_load) { MergeImpl(op_signature, kernel_map, results_); diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index aca418f2d7f8c..e2a4bd694dd49 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -46,7 +46,8 @@ CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep [this](const std::string& value) { return ValidateDeviceModel(value); }); } -CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {} +CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void CudaTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider"; diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index 7b7c855c00715..b2a8134c708c5 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -72,7 +72,8 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { return oss.str(); } -RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {} +RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) + : ITuningContext(ep), info_(info), validator_(ep) {} void 
RocmTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider"; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 43ccfa29624d9..748b6339703db 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1550,6 +1550,16 @@ common::Status InferenceSession::Initialize() { } } + std::vector tuning_results; + ORT_RETURN_IF_ERROR(inference_session_utils::ParseTuningResultsFromModelMetadata(model_metadata_, tuning_results)); + if(!tuning_results.empty()) { + ORT_RETURN_IF_ERROR(SetTuningResults(tuning_results)); + } + else { + LOGS(*session_logger_, WARNING) << "Got empty tuning results."; + } + + return status; } #if defined(_MSC_VER) && !defined(__clang__) @@ -2181,6 +2191,50 @@ const profiling::Profiler& InferenceSession::GetProfiling() const { return session_profiler_; } +#if !defined(ORT_MINIMAL_BUILD) +std::vector InferenceSession::GetTuningResults() const { + std::vector ret; + for (const auto& provider : execution_providers_) { + const auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx != nullptr) { + ret.emplace_back(tuning_ctx->SaveTuningResults()); + } + } + return ret; +} + +Status InferenceSession::SetTuningResults(const std::vector& trs, bool error_on_invalid) { + std::string msg; + + for (size_t i = 0; i < trs.size(); i++) { + const auto& tr = trs[i]; + auto* provider = execution_providers_.Get(tr.ep); + if (provider == nullptr) { + msg = MakeString("Cannot find execution provider ", tr.ep); + LOGS(*session_logger_, WARNING) << msg; + ORT_RETURN_IF(error_on_invalid, msg); + continue; + } + + auto* tuning_ctx = provider->GetTuningContext(); + if (tuning_ctx == nullptr) { + msg = MakeString("Invalid TuningResults (index=", i, "). 
", tr.ep, " does not support TunableOp."); + LOGS(*session_logger_, WARNING) << msg; + ORT_RETURN_IF(error_on_invalid, msg); + continue; + } + + auto status = tuning_ctx->LoadTuningResults(tr); + if (!status.IsOK()) { + msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage()); + LOGS(*session_logger_, WARNING) << msg; + ORT_RETURN_IF(error_on_invalid, msg); + } + } + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const { return session_state_->GetAllocator(mem_info); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index f01523c923385..95b0dde281cda 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -18,6 +18,7 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" +#include "core/framework/tuning_results.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" @@ -448,6 +449,23 @@ class InferenceSession { */ const profiling::Profiler& GetProfiling() const; +#if !defined(ORT_MINIMAL_BUILD) + /** + * Get the TuningResults of TunableOp for every execution provider. + * @return The TuningResults of each execution provider. + */ + std::vector GetTuningResults() const; + + /** + * Set the TuningResults back to each execution provider. Mainly for offline tuning. + * @param trs is the list of TuningResults to be loaded. + * @param error_on_invalid if true, a validation failure is returned as an error; otherwise it is not an error and only a warning log is produced. + * @return OK if success. 
+ */ + Status SetTuningResults(const std::vector& trs, bool error_on_invalid = false); +#endif + + #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) MemoryProfiler& GetMemoryProfiler() { return memory_profiler_; diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index d938228c3af6f..5033f1ed90bef 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -103,6 +103,13 @@ static Status SetEnableProfiling(SessionOptions& session_options, return Status::OK(); } +// This function is called by nlohmann/json +void from_json(const json& j, TuningResults& trs) { + j.at("ep").get_to(trs.ep); + j.at("results").get_to(trs.results); + j.at("validators").get_to(trs.validators); +} + //--------------------------------------------------- //--- end of session options related helpers --- //--------------------------------------------------- @@ -227,6 +234,29 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options "Parsing RunOptions from ModelProto is not supported yet"); } +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + std::vector& results) { + results.clear(); + auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); + if (it == metadata.custom_metadata_map.end()) { + return Status::OK(); + } + + LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model"; + + ORT_TRY { + auto parsed_tuning_results_json = json::parse(it->second); + results = parsed_tuning_results_json.get>(); + } + ORT_CATCH(const std::exception& e) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, + "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". 
Ignoring..."); + } + + return Status::OK(); +} + } // namespace inference_session_utils } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index 9d63f92663ecc..168c20cfa3792 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -44,8 +44,6 @@ class JsonConfigParser { Status ParseRunOptionsFromModelProto(/*out*/ RunOptions& run_options); - Status ParseTuningResultsFromModelProto(/*out*/ std::vector& results); - private: // Logger instance that will be used to log events along the parsing steps const logging::Logger& logger_; @@ -60,6 +58,9 @@ class JsonConfigParser { bool is_ort_config_json_available_ = false; }; +Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, + std::vector& results); + #endif // !defined(ORT_MINIMAL_BUILD) } // namespace inference_session_utils diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index f733c13b6d085..5108f3245213a 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -286,6 +286,12 @@ def run_with_iobinding(self, iobinding, run_options=None): """ self._sess.run_with_iobinding(iobinding._iobinding, run_options) + def get_tuning_results(self): + return self._sess.get_tuning_results() + + def set_tuning_results(self, results, error_on_invalid=False): + return self._sess.set_tuning_results(results, error_on_invalid) + def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices): """ Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead. 
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7d4fb6d32c3de..ceb39882eb4d3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1666,6 +1666,45 @@ including arg name, arg type (contains both type and shape).)pbdoc") status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get()); if (!status.IsOK()) throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + }) + .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list { + py::list ret; + for (const auto& trs : sess->GetSessionHandle()->GetTuningResults()) { + py::dict py_trs; + py_trs["ep"] = trs.ep; + py_trs["results"] = trs.results; + py_trs["validators"] = trs.validators; + ret.append(std::move(py_trs)); + } + + return ret; + }) + .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { + std::vector tuning_results; + for (auto handle: results) { + auto py_trs = handle.cast(); + TuningResults trs; + trs.ep = py_trs["ep"].cast(); + + for (const auto& [py_op_sig, py_kernel_map]: py_trs["results"].cast()) { + KernelMap kernel_map; + for (const auto& [py_params_sig, py_kernel_id]: py_kernel_map.cast()) { + kernel_map[py_params_sig.cast()] = py_kernel_id.cast(); + } + trs.results[py_op_sig.cast()] = kernel_map; + } + + for (const auto& [k, v]: py_trs["validators"].cast()) { + trs.validators[k.cast()] = v.cast(); + } + + tuning_results.emplace_back(std::move(trs)); + } + + Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid); + if (!status.IsOK()) { + throw std::runtime_error("Error in execution: " + status.ErrorMessage()); + } }); py::enum_(m, "ArenaExtendStrategy", py::arithmetic()) diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index f8a94ec59f55c..2da867b96218e 100644 --- 
a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -32,6 +32,8 @@ class TestTuningResultsValidator : public TuningResultsValidator { class TestTuningContext : public ITuningContext { public: + using ITuningContext::ITuningContext; + void EnableTunableOp() override { tuning_enabled_ = true; } void DisableTunableOp() override { tuning_enabled_ = false; } bool IsTunableOpEnabled() const override { return tuning_enabled_; } @@ -41,6 +43,8 @@ class TestTuningContext : public ITuningContext { const TuningResultsValidator& GetTuningResultsValidator() const override { return validator_; } + void ClearCache() { manager_.Clear(); } + private: bool tuning_enabled_{false}; TuningResultsManager manager_{}; @@ -49,7 +53,7 @@ class TestTuningContext : public ITuningContext { class TestEP : public IExecutionProvider { static constexpr const char* kEPType = "TestEP"; - TestTuningContext tuning_ctx_{}; + TestTuningContext tuning_ctx_{this}; public: TestEP() : IExecutionProvider{kEPType, true} {} @@ -58,6 +62,7 @@ class TestEP : public IExecutionProvider { return const_cast(&tuning_ctx_); } + void ClearCache() { tuning_ctx_.ClearCache(); } }; class TestTimer : public ITimer {