diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 3d03b2b4efd5d..d2da1752292ba 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -108,18 +108,25 @@ Model::Model(std::unique_ptr model_proto, const IOnnxRuntimeOpSchema const auto& domain = opSet.domain(); const auto version = opSet.version(); // empty domain and 'ai.onnx' are equivalent - if ((domain.empty() || domain == "ai.onnx") && version < 7) { + if ((domain.empty() || domain == kOnnxDomainAlias) && version < 7) { // TODO: Check if we can upgrade all the current opset 6 models that are being tested // in CI to opset 7 or above LOGS_DEFAULT(WARNING) << "ONNX Runtime only *guarantees* support for models stamped " "with opset version 7 or above for opset domain 'ai.onnx'. " "Please upgrade your model to opset 7 or higher. " "For now, this opset " - << version + << version << " model may run depending upon legacy support " "of some older opset version operators."; } - domain_to_version[domain] = gsl::narrow_cast(version); + // We need to overwrite the domain here with ("") or else the loop below will try to find ("") + // in the map and if not found (when domain == kOnnxDomainAlias), adds an entry for ("", 11). + // This effectively ignores the opset version specified by the model for the onnx domain. + if (domain == kOnnxDomainAlias) { + domain_to_version[kOnnxDomain] = gsl::narrow_cast(version); + } else { + domain_to_version[domain] = gsl::narrow_cast(version); + } } auto domain_map = schema_registry->GetLatestOpsetVersions(false); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 9d4f2e419371c..9c273d846bf07 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -122,7 +122,7 @@ class FuseExecutionProvider : public IExecutionProvider { class InferenceSessionGetGraphWrapper : public InferenceSession { public: explicit InferenceSessionGetGraphWrapper(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) { + logging::LoggingManager* logging_manager) : InferenceSession(session_options, logging_manager) { } const Graph& GetGraph() { @@ -364,7 +364,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { InferenceSessionGetGraphWrapper session_object{so, &DefaultLoggingManager()}; ASSERT_TRUE(session_object.Load(test_model).IsOK()); ASSERT_TRUE(session_object.Initialize().IsOK()); - + // Assert that model has been transformed and identity Node is removed. const auto& graph = session_object.GetGraph(); std::map op_to_count = CountOpsInGraph(graph); @@ -383,7 +383,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { InferenceSession session_object_opt{so_opt, &DefaultLoggingManager()}; ASSERT_TRUE(session_object_opt.Load(so.optimized_model_filepath).IsOK()); ASSERT_TRUE(session_object_opt.Initialize().IsOK()); - + // Assert that re-feed of optimized model with default transform level results // in same runtime model as abs-id-max.onnx with TransformLevel-1. std::ifstream model_fs_session1(so.optimized_model_filepath, ios::in | ios::binary); @@ -1481,5 +1481,16 @@ TEST(InferenceSessionTests, TestParallelExecutionWithCudaProvider) { #endif +TEST(InferenceSessionTests, ModelWithKOnnxDomainAlias) { + SessionOptions so; + so.session_logid = "InferenceSessionTests.NoTimeout"; + InferenceSession session_object{so, &DefaultLoggingManager()}; + std::string file_name = "testdata/test_model_with_fullonnxdomain.onnx"; + auto ret_status = session_object.Load(file_name); + ASSERT_TRUE(ret_status.IsOK()) << ret_status.ErrorMessage(); + ret_status = session_object.Initialize(); + ASSERT_TRUE(ret_status.IsOK()) << ret_status.ErrorMessage(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/test_model_with_fullonnxdomain.onnx b/onnxruntime/test/testdata/test_model_with_fullonnxdomain.onnx new file mode 100644 index 0000000000000..fe97d50aa5529 --- /dev/null +++ b/onnxruntime/test/testdata/test_model_with_fullonnxdomain.onnx @@ -0,0 +1,18 @@ + onnx-example"ai.onnx:j + +X1 +X2Y"Equal:ai.onnx +test-modelZ +X1 +  + +Z +X2 +  + +b +Y +   + +B +ai.onnx \ No newline at end of file