Add session API for getting and setting tuning results. Add embedded tuning results auto loading on session init
cloudhan committed Feb 3, 2023
1 parent 8eae816 commit 05bb29f
Showing 12 changed files with 215 additions and 7 deletions.
12 changes: 10 additions & 2 deletions onnxruntime/core/framework/tunable.h
@@ -136,12 +136,20 @@ class TunableOp {
   ITuningContext* ctx = params->TuningContext();
   if (ctx->IsTunableOpEnabled()) {
     auto& mgr = ctx->GetTuningResultsManager();
-    id = mgr.Lookup(Signature(), params->Signature());
+    auto op_sig = Signature();
+    auto params_sig = params->Signature();
+    id = mgr.Lookup(op_sig, params_sig);
+    if (id > static_cast<int>(ops_.size())) {
+      LOGS_DEFAULT(FATAL) << "Invalid TunableOp kernel id for " << op_sig
+                          << ", id:" << id << ", registered op:" << ops_.size();
+      mgr.Delete(op_sig, params_sig);
+      id = -1;
+    }
     if (id < 0) {
       auto maybe_proxy_params = PreTuning(params);
       id = FindFastest(maybe_proxy_params);
       PostTuning(maybe_proxy_params);
-      mgr.Add(Signature(), params->Signature(), id);
+      mgr.Add(op_sig, params_sig, id);
     }
   }
   ORT_RETURN_IF_ERROR(ops_[id](params));
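This hunk hardens the cached-id path: a kernel id loaded from prior tuning results that is out of range for the currently registered ops is reported, deleted from the manager, and re-tuned. A minimal Python sketch of that control flow, with stand-in names (mgr, ops, find_fastest are illustrative, not ONNX Runtime APIs):

def select_kernel(mgr, ops, op_sig, params_sig, find_fastest):
    kernel_id = mgr.lookup(op_sig, params_sig)  # -1 when nothing is cached
    if kernel_id > len(ops):                    # stale entry, e.g. from another build
        mgr.delete(op_sig, params_sig)          # drop it, then fall through to re-tune
        kernel_id = -1
    if kernel_id < 0:
        kernel_id = find_fastest(ops)           # benchmark every candidate kernel
        mgr.add(op_sig, params_sig, kernel_id)  # cache the winner for next time
    return ops[kernel_id]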
8 changes: 8 additions & 0 deletions onnxruntime/core/framework/tuning_context.h
@@ -17,6 +17,7 @@ class TuningResultsValidator;
 
 class ITuningContext {
  public:
+  explicit ITuningContext(IExecutionProvider* ep) : ep_(ep) {}
   virtual ~ITuningContext() = default;
 
   virtual void EnableTunableOp() = 0;
@@ -27,6 +28,12 @@ class ITuningContext {
   virtual const TuningResultsManager& GetTuningResultsManager() const = 0;
 
   virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0;
+
+  virtual TuningResults SaveTuningResults() const;
+  virtual Status LoadTuningResults(const TuningResults& tr);
+
+ protected:
+  IExecutionProvider* ep_;
 };
 
 class TuningResultsManager {
@@ -38,6 +45,7 @@ class TuningResultsManager {
   int Lookup(const std::string& op_signature, const std::string& params_signature) const;
 
   void Add(const std::string& op_signature, const std::string& params_signature, int best_id);
+  void Delete(const std::string& op_signature, const std::string& params_signature);
 
   void Load(const std::unordered_map<std::string, KernelMap>& results_to_load);
   std::unordered_map<std::string, KernelMap> Dump() const;
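For orientation: the manager stores a two-level map, op signature to params signature to best kernel id (the KernelMap). A hypothetical snapshot, in Python terms, with made-up signatures and ids:

results = {
    "GemmTunableOp_float_NN": {   # op signature
        "M=128,N=128,K=64": 3,    # params signature -> best kernel id
        "M=256,N=256,K=64": 1,
    },
}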
37 changes: 37 additions & 0 deletions onnxruntime/core/framework/tuning_context_impl.h
@@ -21,6 +21,22 @@
 
 namespace onnxruntime {
 
+TuningResults ITuningContext::SaveTuningResults() const {
+  TuningResults tr;
+  tr.ep = ep_->Type();
+  tr.validators = GetTuningResultsValidator().GetAllValidators();
+  tr.results = GetTuningResultsManager().Dump();
+  return tr;
+}
+
+Status ITuningContext::LoadTuningResults(const TuningResults& tr) {
+  ORT_RETURN_IF(tr.ep != ep_->Type(), "EP mismatch");
+  LOGS_DEFAULT(VERBOSE) << "Loading tuning results for " << tr.ep;
+  ORT_RETURN_IF_ERROR(GetTuningResultsValidator().ValidateAll(tr.validators));
+  GetTuningResultsManager().Load(tr.results);
+  return Status::OK();
+}
+
 KernelMap TuningResultsManager::Lookup(const std::string& op_signature) const {
   std::scoped_lock l{lock_};
   auto it = results_.find(op_signature);
@@ -74,6 +90,22 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin
   AddImpl(op_signature, params_signature, best_id, it->second);
 }
 
+// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
+void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
+  std::scoped_lock l{lock_};
+
+  auto it = results_.find(op_signature);
+  if (it == results_.end()) {
+    return;
+  }
+
+  auto it2 = it->second.find(params_signature);
+  if (it2 == it->second.end()) {
+    return;
+  }
+  it->second.erase(it2);
+}
+
 std::unordered_map<std::string, KernelMap> TuningResultsManager::Dump() const {
   std::scoped_lock l{lock_};
   return results_;
@@ -95,6 +127,11 @@ void MergeImpl(
 }
 
 void TuningResultsManager::Load(const std::unordered_map<std::string, KernelMap>& results_to_load) {
+  for (const auto& [op_sig, kernel_map] : results_to_load) {
+    for (const auto& [param_sig, kernel_id] : kernel_map) {
+      LOGS_DEFAULT(VERBOSE) << op_sig << " " << param_sig << " " << kernel_id;
+    }
+  }
   std::scoped_lock l{lock_};
   for (const auto& [op_signature, kernel_map] : results_to_load) {
     MergeImpl(op_signature, kernel_map, results_);
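Load logs every incoming entry at VERBOSE before merging under the lock. MergeImpl's body is not shown in this hunk, but assuming it overwrites per-params entries, the merge is roughly this in Python:

def load(results, results_to_load):
    # Merge each incoming KernelMap into the store; later entries win.
    # Sketch only: assumes MergeImpl has overwrite-on-conflict semantics.
    for op_sig, kernel_map in results_to_load.items():
        results.setdefault(op_sig, {}).update(kernel_map)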
@@ -46,7 +46,8 @@ CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep
       [this](const std::string& value) { return ValidateDeviceModel(value); });
 }
 
-CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {}
+CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info)
+    : ITuningContext(ep), info_(info), validator_(ep) {}
 
 void CudaTuningContext::EnableTunableOp() {
   LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider";
@@ -72,7 +72,8 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const {
   return oss.str();
 }
 
-RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) : info_(info), validator_(ep) {}
+RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info)
+    : ITuningContext(ep), info_(info), validator_(ep) {}
 
 void RocmTuningContext::EnableTunableOp() {
   LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider";
54 changes: 54 additions & 0 deletions onnxruntime/core/session/inference_session.cc
@@ -1550,6 +1550,16 @@ common::Status InferenceSession::Initialize() {
     }
   }
 
+  std::vector<TuningResults> tuning_results;
+  ORT_RETURN_IF_ERROR(inference_session_utils::ParseTuningResultsFromModelMetadata(model_metadata_, tuning_results));
+  if (!tuning_results.empty()) {
+    ORT_RETURN_IF_ERROR(SetTuningResults(tuning_results));
+  }
+  else {
+    LOGS(*session_logger_, WARNING) << "Got empty tuning results.";
+  }
+
+
   return status;
 }
 #if defined(_MSC_VER) && !defined(__clang__)
@@ -2181,6 +2191,50 @@ const profiling::Profiler& InferenceSession::GetProfiling() const {
   return session_profiler_;
 }
 
+#if !defined(ORT_MINIMAL_BUILD)
+std::vector<TuningResults> InferenceSession::GetTuningResults() const {
+  std::vector<TuningResults> ret;
+  for (const auto& provider : execution_providers_) {
+    const auto* tuning_ctx = provider->GetTuningContext();
+    if (tuning_ctx != nullptr) {
+      ret.emplace_back(tuning_ctx->SaveTuningResults());
+    }
+  }
+  return ret;
+}
+
+Status InferenceSession::SetTuningResults(const std::vector<TuningResults>& trs, bool error_on_invalid) {
+  std::string msg;
+
+  for (size_t i = 0; i < trs.size(); i++) {
+    const auto& tr = trs[i];
+    auto* provider = execution_providers_.Get(tr.ep);
+    if (provider == nullptr) {
+      msg = MakeString("Cannot find execution provider ", tr.ep);
+      LOGS(*session_logger_, WARNING) << msg;
+      ORT_RETURN_IF(error_on_invalid, msg);
+      continue;
+    }
+
+    auto* tuning_ctx = provider->GetTuningContext();
+    if (tuning_ctx == nullptr) {
+      msg = MakeString("Invalid TuningResults (index=", i, "). ", tr.ep, " does not support TunableOp.");
+      LOGS(*session_logger_, WARNING) << msg;
+      ORT_RETURN_IF(error_on_invalid, msg);
+      continue;
+    }
+
+    auto status = tuning_ctx->LoadTuningResults(tr);
+    if (!status.IsOK()) {
+      msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage());
+      LOGS(*session_logger_, WARNING) << msg;
+      ORT_RETURN_IF(error_on_invalid, msg);
+    }
+  }
+  return Status::OK();
+}
+#endif  // !defined(ORT_MINIMAL_BUILD)
+
 AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const {
   return session_state_->GetAllocator(mem_info);
 }
18 changes: 18 additions & 0 deletions onnxruntime/core/session/inference_session.h
@@ -18,6 +18,7 @@
 #include "core/framework/kernel_registry_manager.h"
 #include "core/framework/prepacked_weights_container.h"
 #include "core/framework/session_state.h"
+#include "core/framework/tuning_results.h"
 #include "core/graph/basic_types.h"
 #include "core/optimizer/graph_transformer_level.h"
 #include "core/optimizer/graph_transformer_mgr.h"
@@ -448,6 +449,23 @@ class InferenceSession {
    */
   const profiling::Profiler& GetProfiling() const;
 
+#if !defined(ORT_MINIMAL_BUILD)
+  /**
+   * Get the TuningResults of TunableOp for every execution provider.
+   * @return The TuningResults of each execution provider.
+   */
+  std::vector<TuningResults> GetTuningResults() const;
+
+  /**
+   * Set the TuningResults back to each execution provider. Mainly for offline tuning.
+   * @param trs is the list of TuningResults to be loaded.
+   * @param error_on_invalid if true, a validation failure is an error; otherwise it only produces a warning log.
+   * @return OK on success.
+   */
+  Status SetTuningResults(const std::vector<TuningResults>& trs, bool error_on_invalid = false);
+#endif
+
+
 #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
   MemoryProfiler& GetMemoryProfiler() {
     return memory_profiler_;
30 changes: 30 additions & 0 deletions onnxruntime/core/session/inference_session_utils.cc
@@ -103,6 +103,13 @@ static Status SetEnableProfiling(SessionOptions& session_options,
   return Status::OK();
 }
 
+// This function is called by nlohmann/json
+void from_json(const json& j, TuningResults& trs) {
+  j.at("ep").get_to(trs.ep);
+  j.at("results").get_to(trs.results);
+  j.at("validators").get_to(trs.validators);
+}
+
 //---------------------------------------------------
 //--- end of session options related helpers ---
 //---------------------------------------------------
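The from_json hook fixes the serialized shape of a TuningResults entry to three keys. A hypothetical payload, written as a Python literal (all values are illustrative; the validator key names in particular are assumptions, not taken from this diff):

tuning_results = {
    "ep": "ROCMExecutionProvider",
    "results": {
        "GemmTunableOp_float_NN": {"M=128,N=128,K=64": 3},
    },
    "validators": {
        "ORT_VERSION": "1.15.0",   # assumed validator key
        "DEVICE_MODEL": "MI250X",  # assumed validator key
    },
}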
@@ -227,6 +234,29 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options
                          "Parsing RunOptions from ModelProto is not supported yet");
 }
 
+Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata,
+                                           std::vector<TuningResults>& results) {
+  results.clear();
+  auto it = metadata.custom_metadata_map.find(kTuningResultsKeys);
+  if (it == metadata.custom_metadata_map.end()) {
+    return Status::OK();
+  }
+
+  LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model";
+
+  ORT_TRY {
+    auto parsed_tuning_results_json = json::parse(it->second);
+    results = parsed_tuning_results_json.get<std::vector<TuningResults>>();
+  }
+  ORT_CATCH(const std::exception& e) {
+    return ORT_MAKE_STATUS(
+        ONNXRUNTIME, FAIL,
+        "Tuning results stored in the model file cannot be parsed. Error message: ", e.what(), ". Ignoring...");
+  }
+
+  return Status::OK();
+}
+
 } // namespace inference_session_utils
 } // namespace onnxruntime
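This is the piece that makes embedded results auto-load: Initialize() looks them up in the model's custom metadata under kTuningResultsKeys (the constant's string value is not shown in this diff). A sketch of embedding results into a model with the onnx Python package, assuming the key is "tuning_results":

import json
import onnx

# Sketch only: the metadata key "tuning_results" is an assumption here;
# the authoritative name is the kTuningResultsKeys constant in C++.
tuning_results = [{"ep": "ROCMExecutionProvider", "results": {}, "validators": {}}]

model = onnx.load("model.onnx")
prop = model.metadata_props.add()  # ModelProto.metadata_props entry
prop.key = "tuning_results"
prop.value = json.dumps(tuning_results)  # JSON list of TuningResults objects
onnx.save(model, "model_with_tuning.onnx")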
5 changes: 3 additions & 2 deletions onnxruntime/core/session/inference_session_utils.h
@@ -44,8 +44,6 @@ class JsonConfigParser {
 
   Status ParseRunOptionsFromModelProto(/*out*/ RunOptions& run_options);
 
-  Status ParseTuningResultsFromModelProto(/*out*/ std::vector<TuningResults>& results);
-
  private:
   // Logger instance that will be used to log events along the parsing steps
   const logging::Logger& logger_;
@@ -60,6 +58,9 @@
   bool is_ort_config_json_available_ = false;
 };
 
+Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata,
+                                           std::vector<TuningResults>& results);
+
 #endif // !defined(ORT_MINIMAL_BUILD)
 
 } // namespace inference_session_utils
6 changes: 6 additions & 0 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -286,6 +286,12 @@ def run_with_iobinding(self, iobinding, run_options=None):
         """
         self._sess.run_with_iobinding(iobinding._iobinding, run_options)
 
+    def get_tuning_results(self):
+        return self._sess.get_tuning_results()
+
+    def set_tuning_results(self, results, error_on_invalid=False):
+        return self._sess.set_tuning_results(results, error_on_invalid)
+
     def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices):
         """
         Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead.
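Together with the pybind wiring below, a round trip for offline tuning might look like this (model path and provider list are placeholders):

import onnxruntime as ort

# Sketch of the new session-level API added by this commit.
sess = ort.InferenceSession("model.onnx", providers=["ROCMExecutionProvider"])
# ... run representative workloads with TunableOp enabled so results accumulate ...
trs = sess.get_tuning_results()  # list of {"ep", "results", "validators"} dicts

# Later, seed a fresh session with the saved results. With
# error_on_invalid=True, a validation failure raises instead of only warning.
sess2 = ort.InferenceSession("model.onnx", providers=["ROCMExecutionProvider"])
sess2.set_tuning_results(trs, error_on_invalid=True)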
39 changes: 39 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1666,6 +1666,45 @@ including arg name, arg type (contains both type and shape).)pbdoc
         status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
         if (!status.IsOK())
           throw std::runtime_error("Error in execution: " + status.ErrorMessage());
       })
+      .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list {
+        py::list ret;
+        for (const auto& trs : sess->GetSessionHandle()->GetTuningResults()) {
+          py::dict py_trs;
+          py_trs["ep"] = trs.ep;
+          py_trs["results"] = trs.results;
+          py_trs["validators"] = trs.validators;
+          ret.append(std::move(py_trs));
+        }
+
+        return ret;
+      })
+      .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void {
+        std::vector<TuningResults> tuning_results;
+        for (auto handle : results) {
+          auto py_trs = handle.cast<py::dict>();
+          TuningResults trs;
+          trs.ep = py_trs["ep"].cast<py::str>();
+
+          for (const auto& [py_op_sig, py_kernel_map] : py_trs["results"].cast<py::dict>()) {
+            KernelMap kernel_map;
+            for (const auto& [py_params_sig, py_kernel_id] : py_kernel_map.cast<py::dict>()) {
+              kernel_map[py_params_sig.cast<py::str>()] = py_kernel_id.cast<py::int_>();
+            }
+            trs.results[py_op_sig.cast<py::str>()] = kernel_map;
+          }
+
+          for (const auto& [k, v] : py_trs["validators"].cast<py::dict>()) {
+            trs.validators[k.cast<py::str>()] = v.cast<py::str>();
+          }
+
+          tuning_results.emplace_back(std::move(trs));
+        }
+
+        Status status = sess->GetSessionHandle()->SetTuningResults(tuning_results, error_on_invalid);
+        if (!status.IsOK()) {
+          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
+        }
+      });
 
   py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
7 changes: 6 additions & 1 deletion onnxruntime/test/framework/tunable_op_test.cc
@@ -32,6 +32,8 @@ class TestTuningResultsValidator : public TuningResultsValidator {
 
 class TestTuningContext : public ITuningContext {
  public:
+  using ITuningContext::ITuningContext;
+
   void EnableTunableOp() override { tuning_enabled_ = true; }
   void DisableTunableOp() override { tuning_enabled_ = false; }
   bool IsTunableOpEnabled() const override { return tuning_enabled_; }
@@ -41,6 +43,8 @@
 
   const TuningResultsValidator& GetTuningResultsValidator() const override { return validator_; }
 
+  void ClearCache() { manager_.Clear(); }
+
  private:
   bool tuning_enabled_{false};
   TuningResultsManager manager_{};
@@ -49,7 +53,7 @@
 
 class TestEP : public IExecutionProvider {
   static constexpr const char* kEPType = "TestEP";
-  TestTuningContext tuning_ctx_{};
+  TestTuningContext tuning_ctx_{this};
 
  public:
   TestEP() : IExecutionProvider{kEPType, true} {}
@@ -58,6 +62,7 @@ class TestEP : public IExecutionProvider {
     return const_cast<TestTuningContext*>(&tuning_ctx_);
   }
 
+  void ClearCache() { tuning_ctx_.ClearCache(); }
 };
 
 class TestTimer : public ITimer<StreamT> {
