Skip to content

Commit

Permalink
Add TuningResultsValidator
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Feb 3, 2023
1 parent beee0f3 commit 8eae816
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 3 deletions.
29 changes: 29 additions & 0 deletions onnxruntime/core/framework/tuning_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class ITuningContext {

virtual TuningResultsManager& GetTuningResultsManager() = 0;
virtual const TuningResultsManager& GetTuningResultsManager() const = 0;

virtual const TuningResultsValidator& GetTuningResultsValidator() const = 0;
};

class TuningResultsManager {
Expand All @@ -50,4 +52,31 @@ class TuningResultsManager {
std::unordered_map<std::string, KernelMap> results_;
};

// Holds a registry of named (getter, validator) callback pairs describing the environment in which tuning
// results are produced and consumed.
//
// - GetAllValidators() invokes every registered getter and returns the key -> value snapshot.
// - ValidateAll() checks a previously produced snapshot against the currently registered validators.
//
// The base constructor registers "ORT_VERSION", "ORT_GIT_COMMIT" and "ORT_BUILD_CONFIG". Derived classes add
// their own entries through RegisterValidator() and may override the virtual Get*/Validate* hooks.
//
// NOTE: the callbacks registered by the base constructor capture `this`. A copy of a validator therefore keeps
// callbacks that are bound to the object it was copied from; prefer holding a validator in place.
class TuningResultsValidator {
 public:
  // Produces the current value for a key.
  using GetFunc = std::function<std::string()>;
  // Checks a stored value for a key against the current environment.
  using ValidateFunc = std::function<Status(const std::string&)>;
  // key -> (getter, validator)
  using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;

  TuningResultsValidator();

  // The class is used polymorphically (it has virtual hooks and is derived from by the EP specific validators),
  // so destruction through a base pointer must be well defined.
  virtual ~TuningResultsValidator() = default;

  // Returns the value produced by every registered getter, keyed by validator name.
  std::unordered_map<std::string, std::string> GetAllValidators() const;

  // Validates `to_validate` (validator name -> stored value) against the registered validators.
  Status ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;

 protected:
  // Registers a getter/validator pair under `key`. Registering the same key twice is enforced against.
  void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);

  virtual std::string GetOrtVersion() const;
  virtual Status ValidateOrtVersion(const std::string& value) const;

  virtual std::string GetOrtGitCommit() const;
  virtual Status ValidateOrtGitCommit(const std::string& value) const;

  virtual std::string GetOrtBuildConfig() const;
  virtual Status ValidateOrtBuildConfig(const std::string& value) const;

 private:
  GetValidateFuncs validators_;
};

} // namespace onnxruntime
142 changes: 141 additions & 1 deletion onnxruntime/core/framework/tuning_context_impl.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file contains the implementation of TuningResultsManager. At the moment, there is no necessity to expose these
// This file contains the implementation of TuningContext. At the moment, there is no necessity to expose these
// methods as OrtApis. This will cause missing symbols when loading provider dynamic libraries, because the libraries
// are not whole-archive linked and these symbols are not referenced at framework level. To circumvent this problem,
// each EP must have one, and only one, translation unit that includes this file.
Expand All @@ -11,6 +11,10 @@

#pragma once

#include <functional>
#include <unordered_set>
#include <utility>

#include "core/framework/tunable.h"
#include "core/framework/tuning_context.h"
#include "core/framework/tuning_results.h"
Expand Down Expand Up @@ -106,4 +110,140 @@ void TuningResultsManager::Clear() {
results_ = {};
}

// Verifies that the mandatory keys ("ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG") are present both in the
// registered getter/validator table `gv_funcs` and in the key/value pairs `to_check`.
// Every missing key is appended to the error message before the check fails, so a single failure reports all
// problems at once.
Status CheckMandatoryKeys(
    const TuningResultsValidator::GetValidateFuncs& gv_funcs,
    const std::unordered_map<std::string, std::string>& to_check) {
  constexpr const std::array mandatory_keys{"ORT_VERSION", "ORT_GIT_COMMIT", "ORT_BUILD_CONFIG"};

  std::ostringstream oss;
  bool passed = true;
  for (const char* name : mandatory_keys) {
    const bool is_registered = gv_funcs.count(name) != 0;
    const bool is_provided = to_check.count(name) != 0;

    if (!is_registered) {
      oss << "key=\"" << name << "\" is not registered for Get and Validate. ";
    }
    if (!is_provided) {
      oss << "key=\"" << name << "\" is not provided for validation. ";
    }

    passed = passed && is_registered && is_provided;
  }

  ORT_RETURN_IF(!passed, oss.str());
  return Status::OK();
}

// Verifies that the set of registered validator keys (`gv_funcs`) and the set of provided keys (`to_check`) are
// identical. Keys that are registered but not provided are reported first, followed by keys that are provided
// but not registered; each group is reported in sorted key order so the message is deterministic.
Status CheckKeysMatching(
    const TuningResultsValidator::GetValidateFuncs& gv_funcs,
    const std::unordered_map<std::string, std::string>& to_check) {
  // Collects the keys of a map into a sorted vector.
  auto sorted_keys_of = [](const auto& map) {
    std::vector<std::string> keys;
    for (const auto& entry : map) {
      keys.push_back(entry.first);
    }
    std::sort(keys.begin(), keys.end());
    return keys;
  };
  const std::vector<std::string> required_keys = sorted_keys_of(gv_funcs);
  const std::vector<std::string> provided_keys = sorted_keys_of(to_check);

  bool matched = true;
  std::ostringstream oss;

  // Keys are unique within each map, so a direct membership test on the other map is equivalent to comparing
  // against the intersection of both key sets.
  for (const auto& k : required_keys) {
    if (to_check.count(k) == 0) {
      matched = false;
      oss << "Unmatched validator: \"" << k << "\" is required, but the tuning results does not provide it. ";
    }
  }
  for (const auto& k : provided_keys) {
    if (gv_funcs.count(k) == 0) {
      matched = false;
      oss << "Unmatched validator: \"" << k << "\" is provided, but onnxruntime is unable to consume it. ";
    }
  }

  ORT_RETURN_IF(!matched, oss.str());
  return Status::OK();
}

// Getter registered under "ORT_VERSION": reports the ORT_VERSION the library was compiled with.
std::string TuningResultsValidator::GetOrtVersion() const {
  return std::string(ORT_VERSION);
}

// Validator registered under "ORT_VERSION".
// `value` is the version string stored with the tuning results. It is compared against GetOrtVersion() (rather
// than the ORT_VERSION macro directly) so that a subclass overriding the getter stays consistent with this
// check, mirroring ValidateOrtBuildConfig. The message reports both versions to make a mismatch diagnosable.
Status TuningResultsValidator::ValidateOrtVersion(const std::string& value) const {
  const auto current = GetOrtVersion();
  ORT_RETURN_IF(current != value,
                "onnxruntime version mismatch: tuning results produced with onnxruntime \"", value,
                "\", current library is onnxruntime \"", current, "\"");
  return Status::OK();
}

// Getter registered under "ORT_GIT_COMMIT".
// TODO: not implemented yet, an empty string is reported for now.
std::string TuningResultsValidator::GetOrtGitCommit() const {
  return std::string();
}

// Validator registered under "ORT_GIT_COMMIT".
// TODO: not implemented yet, every stored value is accepted and the argument is ignored.
Status TuningResultsValidator::ValidateOrtGitCommit(const std::string& /*value*/) const {
  return Status::OK();
}

// Getter registered under "ORT_BUILD_CONFIG". The base implementation reports an empty configuration string;
// it is virtual so that a derived validator can describe its own build options.
std::string TuningResultsValidator::GetOrtBuildConfig() const {
  return std::string();
}

// Validator registered under "ORT_BUILD_CONFIG".
// `value` is the build configuration stored with the tuning results; it must be equal to what the (virtual)
// GetOrtBuildConfig() reports for the running library.
Status TuningResultsValidator::ValidateOrtBuildConfig(const std::string& value) const {
  const std::string current = GetOrtBuildConfig();
  ORT_RETURN_IF(
      current != value,
      "onnxruntime building configuration mismatch: tuning results produced with library \"",
      value,
      "\", current library built with \"",
      current,
      "\"");
  return Status::OK();
}

// Registers the validators every tuning result must carry: "ORT_VERSION", "ORT_GIT_COMMIT" and
// "ORT_BUILD_CONFIG".
// The callbacks capture `this` and call the virtual hooks only when they are invoked later, not during
// construction.
TuningResultsValidator::TuningResultsValidator() {
  RegisterValidator(
      "ORT_VERSION",
      [this]() { return GetOrtVersion(); },
      [this](const std::string& value) { return ValidateOrtVersion(value); });

  RegisterValidator(
      "ORT_GIT_COMMIT",
      [this]() { return GetOrtGitCommit(); },
      [this](const std::string& value) { return ValidateOrtGitCommit(value); });

  RegisterValidator(
      "ORT_BUILD_CONFIG",
      [this]() { return GetOrtBuildConfig(); },
      [this](const std::string& value) { return ValidateOrtBuildConfig(value); });
}

// Validates a key -> value snapshot against the registered validators.
// First the mandatory keys and the exact match between registered and provided keys are checked; only then is
// each provided value handed to the validator registered under the same key. The first failing status is
// returned.
Status TuningResultsValidator::ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const {
  ORT_RETURN_IF_ERROR(CheckMandatoryKeys(validators_, to_validate));
  ORT_RETURN_IF_ERROR(CheckKeysMatching(validators_, to_validate));

  for (const auto& entry : to_validate) {
    const auto it = validators_.find(entry.first);
    // The key sets were checked to be matching above, so the lookup is expected to succeed.
    ORT_ENFORCE(it != validators_.cend());
    ORT_RETURN_IF_ERROR(it->second.second(entry.second));
  }

  return Status::OK();
}

// Invokes the getter of every registered validator and returns the produced values keyed by validator name.
std::unordered_map<std::string, std::string> TuningResultsValidator::GetAllValidators() const {
  std::unordered_map<std::string, std::string> ret;
  for (const auto& entry : validators_) {
    // entry.second is the (getter, validator) pair; only the getter is needed here.
    ret.emplace(entry.first, entry.second.first());
  }
  return ret;
}

// Stores the getter `gf` and the validator `vf` under `key`. A key may be registered only once; a duplicate
// registration trips the ORT_ENFORCE below.
void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) {
  ORT_ENFORCE(validators_.find(key) == validators_.end());
  validators_.emplace(key, std::make_pair(gf, vf));
}

} // namespace onnxruntime
38 changes: 37 additions & 1 deletion onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,39 @@ namespace onnxruntime {
namespace cuda {
namespace tunable {

CudaTuningContext::CudaTuningContext(CUDAExecutionProvider*, TunableOpInfo* info) : info_(info) {}
// Returns the integer reported by cudaRuntimeGetVersion, formatted as a decimal string.
// A failing CUDA call is turned into an exception by CUDA_CALL_THROW.
std::string GetCudaVersion() {
  int runtime_version = 0;
  CUDA_CALL_THROW(cudaRuntimeGetVersion(&runtime_version));
  return std::to_string(runtime_version);
}

// Validator for "CUDA_VERSION": `value` is the version string stored with the tuning results and must be equal
// to what GetCudaVersion() reports for the running process.
Status ValidateCudaVersion(const std::string& value) {
  const std::string current = GetCudaVersion();
  ORT_RETURN_IF(
      current != value,
      "CUDA runtime version mismatch: tuning results produced with CUDA ",
      value,
      ", onnxruntime currently run with CUDA ",
      current);
  return Status::OK();
}

// Getter for "DEVICE_MODEL": the `name` field of the device properties exposed by the execution provider.
std::string CudaTuningResultsValidator::GetDeviceModel() const {
  const auto& device_prop = ep_->GetDeviceProp();
  return std::string(device_prop.name);
}

// Validator for "DEVICE_MODEL": `value` is the device name stored with the tuning results and must be equal to
// the name of the device the execution provider currently runs on.
Status CudaTuningResultsValidator::ValidateDeviceModel(const std::string& value) const {
  const std::string current = GetDeviceModel();
  ORT_RETURN_IF(
      current != value,
      "Device model mismatch: tuning results produced with device ",
      value,
      ", onnxruntime currently run with device ",
      current);
  return Status::OK();
}

CudaTuningResultsValidator::CudaTuningResultsValidator(CUDAExecutionProvider* ep) : ep_(ep) {
RegisterValidator("CUDA_VERSION", GetCudaVersion, ValidateCudaVersion);
RegisterValidator(
"DEVICE_MODEL",
[this]() { return GetDeviceModel(); },
[this](const std::string& value) { return ValidateDeviceModel(value); });
}

// Keeps the non-owning `info` handle and builds the validator for the given execution provider.
CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info)
    : info_(info),
      validator_(ep) {
}

void CudaTuningContext::EnableTunableOp() {
LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider";
Expand All @@ -38,6 +70,10 @@ const TuningResultsManager& CudaTuningContext::GetTuningResultsManager() const {
return manager_;
}

// Exposes the CUDA validator owned by this context through the base class interface.
const TuningResultsValidator& CudaTuningContext::GetTuningResultsValidator() const {
  return validator_;
}

} // namespace tunable
} // namespace cuda
} // namespace onnxruntime
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ class CUDAExecutionProvider;
namespace cuda {
namespace tunable {

// Validator for tuning results of the CUDA execution provider. In addition to the validators registered by
// TuningResultsValidator, its constructor registers "CUDA_VERSION" and "DEVICE_MODEL".
class CudaTuningResultsValidator : public TuningResultsValidator {
 public:
  // `ep` is kept as a non-owning pointer and is dereferenced when the device model is queried.
  // Marked explicit so that an execution provider pointer is never implicitly converted into a validator.
  explicit CudaTuningResultsValidator(CUDAExecutionProvider* ep);

 protected:
  // Name of the device the execution provider runs on.
  std::string GetDeviceModel() const;
  // Checks that `value` equals the current device name.
  Status ValidateDeviceModel(const std::string& value) const;

 private:
  CUDAExecutionProvider* ep_;  // non-owning handle
};

class CudaTuningContext : public ITuningContext {
public:
explicit CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* info);
Expand All @@ -26,9 +38,12 @@ class CudaTuningContext : public ITuningContext {
TuningResultsManager& GetTuningResultsManager() override;
const TuningResultsManager& GetTuningResultsManager() const override;

const TuningResultsValidator& GetTuningResultsValidator() const override;

private:
TunableOpInfo* info_; // non-owning handle
TuningResultsManager manager_;
CudaTuningResultsValidator validator_;
};

} // namespace tunable
Expand Down
64 changes: 63 additions & 1 deletion onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,65 @@ namespace onnxruntime {
namespace rocm {
namespace tunable {

RocmTuningContext::RocmTuningContext(ROCMExecutionProvider*, TunableOpInfo* info) : info_(info) {}
std::string GetHipVersion() {
int version;
HIP_CALL_THROW(hipRuntimeGetVersion(&version));
return std::to_string(version);
}

// Validator for "HIP_VERSION": `value` is the version string stored with the tuning results and must be equal
// to what GetHipVersion() reports for the running process.
Status ValidateHipVersion(const std::string& value) {
  const std::string current = GetHipVersion();
  ORT_RETURN_IF(
      current != value,
      "HIP runtime version mismatch: tuning results produced with HIP ",
      value,
      ", onnxruntime currently run with HIP ",
      current);
  return Status::OK();
}

// Returns the version string reported by rocblas_get_version_string.
// A failing rocBLAS call is turned into an exception by ROCBLAS_CALL_THROW.
std::string GetRocBlasVersion() {
  char buf[64] = {};
  // The length handed to rocBLAS must be the real capacity of `buf`. Passing a larger value would allow
  // rocBLAS to write past the end of the stack buffer. With the correct length, a version string that does not
  // fit is expected to surface as an error status (and therefore an exception) instead of an overflow.
  ROCBLAS_CALL_THROW(rocblas_get_version_string(buf, sizeof(buf)));
  buf[sizeof(buf) - 1] = '\0';  // defensive: guarantee termination before converting to std::string
  return buf;
}

// Validator for "ROCBLAS_VERSION": `value` is the version string stored with the tuning results and must be
// equal to what GetRocBlasVersion() reports for the running process.
Status ValidateRocBlasVersion(const std::string& value) {
  const std::string current = GetRocBlasVersion();
  ORT_RETURN_IF(
      current != value,
      "rocblas runtime version mismatch: tuning results produced with rocblas ",
      value,
      ", onnxruntime currently run with rocblas ",
      current);
  return Status::OK();
}

// Getter for "DEVICE_MODEL": the `name` field of the device properties exposed by the execution provider.
std::string RocmTuningResultsValidator::GetDeviceModel() const {
  const auto& device_prop = ep_->GetDeviceProp();
  return std::string(device_prop.name);
}

// Validator for "DEVICE_MODEL": `value` is the device name stored with the tuning results and must be equal to
// the name of the device the execution provider currently runs on.
Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const {
  const std::string current = GetDeviceModel();
  ORT_RETURN_IF(
      current != value,
      "Device model mismatch: tuning results produced with device ",
      value,
      ", onnxruntime currently run with device ",
      current);
  return Status::OK();
}

// Adds the ROCm specific validators ("HIP_VERSION", "ROCBLAS_VERSION" and "DEVICE_MODEL") on top of those
// registered by the base class constructor. `ep` is stored as a non-owning pointer and is dereferenced when the
// device model is queried.
RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_(ep) {
  RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion);
  RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion);

  const auto get_device_model = [this]() { return GetDeviceModel(); };
  const auto validate_device_model = [this](const std::string& value) { return ValidateDeviceModel(value); };
  RegisterValidator("DEVICE_MODEL", get_device_model, validate_device_model);
}

// Overrides the base getter for "ORT_BUILD_CONFIG". The result has the form
// "USE_CK=<value>|USE_ROCBLAS_EXTENSION_API=<0 or 1>|", where the first value is the USE_COMPOSABLE_KERNEL macro
// and the second tells whether USE_ROCBLAS_EXTENSION_API was defined at compile time.
std::string RocmTuningResultsValidator::GetOrtBuildConfig() const {
#ifdef USE_ROCBLAS_EXTENSION_API
  constexpr int rocblas_extension_api_enabled = 1;
#else
  constexpr int rocblas_extension_api_enabled = 0;
#endif

  std::ostringstream oss;
  oss << "USE_CK=" << USE_COMPOSABLE_KERNEL << "|";
  oss << "USE_ROCBLAS_EXTENSION_API=" << rocblas_extension_api_enabled << "|";
  return oss.str();
}

// Keeps the non-owning `info` handle and builds the validator for the given execution provider.
RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info)
    : info_(info),
      validator_(ep) {
}

void RocmTuningContext::EnableTunableOp() {
LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider";
Expand All @@ -38,6 +96,10 @@ const TuningResultsManager& RocmTuningContext::GetTuningResultsManager() const {
return manager_;
}

// Exposes the ROCm validator owned by this context through the base class interface.
const TuningResultsValidator& RocmTuningContext::GetTuningResultsValidator() const {
  return validator_;
}

} // namespace tunable
} // namespace rocm
} // namespace onnxruntime
17 changes: 17 additions & 0 deletions onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ class ROCMExecutionProvider;
namespace rocm {
namespace tunable {

// Validator for tuning results of the ROCm execution provider. In addition to the validators registered by
// TuningResultsValidator, its constructor registers "HIP_VERSION", "ROCBLAS_VERSION" and "DEVICE_MODEL", and it
// overrides the build configuration getter.
class RocmTuningResultsValidator : public TuningResultsValidator {
 public:
  // `ep` is kept as a non-owning pointer and is dereferenced when the device model is queried.
  // Marked explicit so that an execution provider pointer is never implicitly converted into a validator.
  explicit RocmTuningResultsValidator(ROCMExecutionProvider* ep);

 protected:
  // Reports the ROCm specific build options instead of the base class value.
  std::string GetOrtBuildConfig() const override;

  // Name of the device the execution provider runs on.
  std::string GetDeviceModel() const;
  // Checks that `value` equals the current device name.
  Status ValidateDeviceModel(const std::string& value) const;

 private:
  ROCMExecutionProvider* ep_;  // non-owning handle
};

class RocmTuningContext : public ITuningContext {
public:
explicit RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info);
Expand All @@ -26,9 +40,12 @@ class RocmTuningContext : public ITuningContext {
TuningResultsManager& GetTuningResultsManager() override;
const TuningResultsManager& GetTuningResultsManager() const override;

const TuningResultsValidator& GetTuningResultsValidator() const override;

private:
TunableOpInfo* info_; // non-owning handle
TuningResultsManager manager_;
RocmTuningResultsValidator validator_;
};

} // namespace tunable
Expand Down
Loading

0 comments on commit 8eae816

Please sign in to comment.