Skip to content

Commit

Permalink
Merge two helpers involving the kernel def hashes into one file (#10609)
Browse files Browse the repository at this point in the history
* Merge two helpers involving the kernel def hashes used by ORT format models. Add codeowners entry to ensure updates involving hashes are checked.
  • Loading branch information
skottmckay authored Feb 23, 2022
1 parent ea7f773 commit e0d1d69
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 91 deletions.
1 change: 1 addition & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ samples/python/training/** @thiagocrepaldi @tlh20 @liqunfu @baijumeswani @Sherlo

# Mobile
/onnxruntime/test/testdata/kernel_def_hashes/ @skottmckay @gwang-msft @YUNQIUGUO @edgchen1
/onnxruntime/core/framework/kernel_def_hash_helpers.* @skottmckay @gwang-msft @YUNQIUGUO @edgchen1

# Contrib Ops
onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc @zhanghuanrong @chenfucn @yufenglee @yihonglyu @snnn
Expand Down
71 changes: 71 additions & 0 deletions onnxruntime/core/framework/kernel_def_hash_helpers.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include "core/framework/kernel_def_hash_helpers.h"

namespace onnxruntime {
namespace utils {
std::optional<HashValue> GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version) {
// Layout tranformer can add new nodes to the graph.
// Since layout transformation can happen in an extended build, if these nodes are not picked up and compiled by
// NNAPI or other compiling EPs then we need a way to get the hashes for these nodes. Since the infrastructure
// as well as op_schema required to generate these hashes is not available in an extended minimal build,
// we maintain a static map of nodes to hash value. This hash value can then be used to retireive the
// kernel for the given op.
static std::unordered_map<std::string, HashValue> static_kernel_hashes{
{"Transpose_1", 4324835766923221184ULL},
{"Transpose_13", 17267477159887372848ULL},
{"Squeeze_1", 12889825108950034784ULL},
{"Squeeze_11", 14725795030460042064ULL},
{"Squeeze_13", 16122603335179721968ULL},
{"UnSqueeze_1", 15964030255371555232ULL},
{"UnSqueeze_11", 16989589986691430224ULL},
{"UnSqueeze_13", 9466011545409597224ULL},
{"Gather_1", 625186873870077080ULL},
{"Gather_11", 11761559382112736008ULL},
{"Gather_13", 7462749543760614528ULL},
{"Identity_1", 18001636502361632792ULL},
{"Identity_13", 16879814636194901248ULL},
{"Identity_14", 16515685968327103576ULL},
{"Identity_16", 17661628575887109792ULL},
};

auto key = op_type + "_" + std::to_string(since_version);
auto iter = static_kernel_hashes.find(key);
if (iter != static_kernel_hashes.end()) {
return iter->second;
}

return std::nullopt;
}

void UpdateHashForBackwardsCompatibility(HashValue& hash) {
// map of old hash to new hash if we were forced to break backwards compatibility for a kernel registration
//
// If we need to update the hash for an existing registration, an entry needs to be added here to map the
// old hash to the new. This should rarely be required as historically the only need for it was fixing
// kernel registrations with invalid type constraints. Please carefully read through the information at the top of
// onnxruntime/test/providers/kernel_def_hash_test.cc regarding how/when hashes might change and the best way to
// address that.
static const std::unordered_map<HashValue, HashValue> hashes{
// old new domain, operator, opset[, type]
{2832535737534577496ULL, 16708009824840936392ULL}, // kOnnxDomain, Dropout, 7
{12198479371038564912ULL, 1718418059112844640ULL}, // kOnnxDomain, Scan, 9
{2560955351529676608ULL, 3668627007850399040ULL}, // kOnnxDomain, Scan, 11
{10232409728231027688ULL, 5212043150202938416ULL}, // kOnnxDomain, Not, 1
{11912523891622051440ULL, 10225383741733918632ULL}, // kOnnxDomain, RoiAlign, 10, float
{18084231515768318048ULL, 17022700455473327752ULL}, // kOnnxDomain, RoiAlign, 10, double
{14033689580222898712ULL, 634727773751317256ULL}, // kOnnxDomain, GatherND, 11
{646512416908411600ULL, 3064028185911332496ULL}, // kOnnxDomain, GatherND, 12
{15019893097608892000ULL, 11311962292460032936ULL}, // kOnnxDomain, GatherND, 13
{14259324427750852648ULL, 7767393334034626736ULL}, // kOnnxDomain, StringNormalizer, 10
// contrib ops
{7642430665819070720ULL, 8620498355864235632ULL}, // kMSDomain, CropAndResize, 1
{15019666093341768288ULL, 11924582339825775592ULL}}; // kMSDomain, GridSample, 1

auto iter = hashes.find(hash);
if (iter != hashes.cend()) {
// hash was updated in newer version of ORT kernel registrations
hash = iter->second;
}
}

} // namespace utils
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

#pragma once

#include <unordered_map>
#include <string>
#include "core/common/common.h"
#include "core/graph/basic_types.h"

#include <optional>
#include "core/common/basic_types.h"
namespace onnxruntime {
#include <string>

namespace onnxruntime::utils {
/**
* @brief Gets the hash value for provided op type + version combination if it is available, otherwise
* returns a nullopt. The hash value is available if this node was added by layout transformer. For all other
Expand All @@ -16,4 +18,11 @@ namespace onnxruntime {
* @return std::optional<HashValue>
*/
std::optional<HashValue> GetHashValueFromStaticKernelHashMap(const std::string& op_type, int since_version);
} // namespace onnxruntime

/**
* Get replacement hash for backwards compatibility if we had to modify an existing kernel registration.
* @param hash Hash to update if needed.
*/
void UpdateHashForBackwardsCompatibility(HashValue& hash);

} // namespace onnxruntime::utils
6 changes: 3 additions & 3 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
#include "core/common/safeint.h"
#include "core/flatbuffers/schema/ort.fbs.h"
#include "core/framework/allocator.h"
#include "core/framework/kernel_def_hash_helpers.h"
#include "core/framework/node_index_info.h"
#include "core/framework/op_kernel.h"
#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/session_state_flatbuffers_utils.h"
#include "core/framework/session_state_utils.h"
#include "core/framework/utils.h"
#include "core/framework/static_kernel_def_hashes.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

Expand Down Expand Up @@ -971,7 +971,7 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
auto add_kernel_by_hash =
[&kernel_registry_manager, this](const Node& node, HashValue hash) {
const KernelCreateInfo* kci = nullptr;
fbs::utils::UpdateHashForBackwardsCompatibility(hash);
utils::UpdateHashForBackwardsCompatibility(hash);

ORT_RETURN_IF_NOT(kernel_registry_manager.SearchKernelRegistriesByHash(hash, &kci),
"Failed to find kernel def hash (", hash, ") in kernel registries for ",
Expand Down Expand Up @@ -1024,7 +1024,7 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
for (const auto& node : graph_.Nodes()) {
if (kernel_create_info_map_.count(node.Index()) == 0) {
if (node.Domain() == kOnnxDomain || node.Domain() == kOnnxDomainAlias) {
auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (kernel_hash.has_value()) {
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, *kernel_hash));
} else {
Expand Down
36 changes: 3 additions & 33 deletions onnxruntime/core/framework/session_state_flatbuffers_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "core/framework/session_state_flatbuffers_utils.h"

#include "core/framework/kernel_def_hash_helpers.h"

namespace onnxruntime::fbs::utils {

std::string GetSubgraphId(const NodeIndex node_idx, const std::string& attr_name) {
Expand Down Expand Up @@ -48,7 +50,7 @@ FbsSessionStateViewer::NodeKernelInfo FbsSessionStateViewer::GetNodeKernelInfo(I
const auto* const fbs_kernel_def_hashes = fbs_kcis->kernel_def_hashes();

HashValue hash = fbs_kernel_def_hashes->Get(idx);
UpdateHashForBackwardsCompatibility(hash);
onnxruntime::utils::UpdateHashForBackwardsCompatibility(hash);

return {fbs_node_indices->Get(idx), hash};
}
Expand Down Expand Up @@ -76,36 +78,4 @@ Status FbsSessionStateViewer::GetSubgraphSessionState(NodeIndex node_idx, const
fbs_subgraph_session_state_out = fbs_subgraph_session_state;
return Status::OK();
}

void UpdateHashForBackwardsCompatibility(HashValue& hash) {
// map of old hash to new hash if we were forced to break backwards compatibility for a kernel registration
//
// If we need to update the hash for an existing registration, an entry needs to be added here to map the
// old hash to the new. This should rarely be required as historically the only need for it was fixing
// kernel registrations with invalid type constraints. Please carefully read through the information at the top of
// onnxruntime/test/providers/kernel_def_hash_test.cc regarding how/when hashes might change and the best way to
// address that.
static const std::unordered_map<HashValue, HashValue> hashes{
// old new domain, operator, opset[, type]
{2832535737534577496ULL, 16708009824840936392ULL}, // kOnnxDomain, Dropout, 7
{12198479371038564912ULL, 1718418059112844640ULL}, // kOnnxDomain, Scan, 9
{2560955351529676608ULL, 3668627007850399040ULL}, // kOnnxDomain, Scan, 11
{10232409728231027688ULL, 5212043150202938416ULL}, // kOnnxDomain, Not, 1
{11912523891622051440ULL, 10225383741733918632ULL}, // kOnnxDomain, RoiAlign, 10, float
{18084231515768318048ULL, 17022700455473327752ULL}, // kOnnxDomain, RoiAlign, 10, double
{14033689580222898712ULL, 634727773751317256ULL}, // kOnnxDomain, GatherND, 11
{646512416908411600ULL, 3064028185911332496ULL}, // kOnnxDomain, GatherND, 12
{15019893097608892000ULL, 11311962292460032936ULL}, // kOnnxDomain, GatherND, 13
{14259324427750852648ULL, 7767393334034626736ULL}, // kOnnxDomain, StringNormalizer, 10
// contrib ops
{7642430665819070720ULL, 8620498355864235632ULL}, // kMSDomain, CropAndResize, 1
{15019666093341768288ULL, 11924582339825775592ULL}}; // kMSDomain, GridSample, 1

auto iter = hashes.find(hash);
if (iter != hashes.cend()) {
// hash was updated in newer version of ORT kernel registrations
hash = iter->second;
}
}

} // namespace onnxruntime::fbs::utils
7 changes: 0 additions & 7 deletions onnxruntime/core/framework/session_state_flatbuffers_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,4 @@ class FbsSessionStateViewer {
private:
const fbs::SessionState& fbs_session_state_;
};

/**
* Get replacement hash for backwards compatibility if we had to modify an existing kernel registration.
* @param hash Hash to update if needed.
*/
void UpdateHashForBackwardsCompatibility(HashValue& hash);

} // namespace onnxruntime::fbs::utils
37 changes: 0 additions & 37 deletions onnxruntime/core/framework/static_kernel_def_hashes.cc

This file was deleted.

4 changes: 2 additions & 2 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/utils.h"
#include "core/framework/static_kernel_def_hashes.h"
#include "core/framework/kernel_def_hash_helpers.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/graph_transformer_utils.h"
Expand Down Expand Up @@ -1220,7 +1220,7 @@ Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs
// The following loop fetches the hash values for these nodes.
for (const auto& node : graph.Nodes()) {
if (node.GetExecutionProviderType().empty()) {
auto kernel_hash = GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (kernel_hash.has_value()) {
ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value()));
}
Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/test/providers/kernel_def_hash_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
* In the unlikely event that we need to make a change to the kernel def
* hashing that breaks backward compatibility, the expected values may need to
* be updated. You will also need to update UpdateHashForBackwardsCompatibility
* in onnxruntime/core/framework/session_state_flatbuffers_utils.cc, add a
* in onnxruntime/core/framework/kernel_def_hash_helpers.cc, add a
* test model for the operator in question to onnxruntime/test/testdata/ort_backwards_compat
* and update OrtModelOnlyTests.TestBackwardsCompat in onnxruntime/test/framework/ort_model_only_test.cc
* to load the new model and validate the hash replacement works correctly.
Expand Down Expand Up @@ -200,9 +200,12 @@ TEST(KernelDefHashTest, ExpectedCpuKernelDefHashes) {
// Adding this test here because resolution for this test failure requires fetching the hash
// for one of the ops in the list below and this file has information around that.
// Please update the following 3 places:
// 1. api_impl.cc "onnx_ops_available_versions" map, include the latest version in the map
// 2. static_kernel_def_hashes.cc "static_kernel_hashes" include an entry for latest version and it's associated hash
// 3. This file "onnx_ops_available_versions" map, include the latest version in the map
// 1. optimizer/transpose_optimizer/optimizer_api_impl.cc "onnx_ops_available_versions" map,
// include the latest version in the map
// 2. framework/kernel_def_hash_helpers.cc:GetHashValueFromStaticKernelHashMap "static_kernel_hashes" map,
// add an entry for latest version and its associated hash
// 3. KernelDefHashTest.TestNewOpsVersionSupportDuringLayoutTransform "onnx_ops_available_versions" map,
// include the latest version in the map
TEST(KernelDefHashTest, TestNewOpsVersionSupportDuringLayoutTransform) {
static const std::unordered_map<std::string, std::vector<int>> onnx_ops_available_versions = {
{"Squeeze", {1, 11, 13}},
Expand Down

0 comments on commit e0d1d69

Please sign in to comment.