From fd69fd84edccc87a382fae55e8435c8af8322d88 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 17 Jan 2020 13:17:08 -0800 Subject: [PATCH 1/2] enable gpu evals --- .../providers/dml/dml_provider_factory.cc | 18 ++++++-- winml/adapter/winml_adapter_c_api.cpp | 4 +- winml/adapter/winml_adapter_dml.cpp | 17 +++++++- winml/adapter/winml_adapter_session.cpp | 42 ++++++++++++++++++- .../Api.Ort/OnnxruntimeDmlSessionBuilder.cpp | 27 ++++++------ winml/lib/Api.Ort/OnnxruntimeEngine.cpp | 22 ++++++++-- winml/lib/Api.Ort/OnnxruntimeEngine.h | 4 +- winml/lib/Api/LearningModelBinding.cpp | 2 +- winml/lib/Api/impl/TensorBase.h | 2 +- winml/lib/Common/inc/iengine.h | 4 +- 10 files changed, 112 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 5194a4b18f5a4..e05652a33513a 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -24,14 +24,22 @@ struct DMLProviderFactory : IExecutionProviderFactory { ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; + void SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode); private: ComPtr dml_device_{}; ComPtr cmd_queue_{}; + AllocatorRoundingMode rounding_mode_ = AllocatorRoundingMode::Enabled; }; std::unique_ptr DMLProviderFactory::CreateProvider() { - return Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get()); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get()); + Dml::SetDefaultRoundingMode(provider.get(), rounding_mode_); + return provider; +} + +void DMLProviderFactory::SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode) { + rounding_mode_ = rounding_mode; } std::shared_ptr CreateExecutionProviderFactory_DML(IDMLDevice* dml_device, @@ -55,6 +63,11 @@ std::shared_ptr CreateExecutionProviderFactory_DML(ID return std::make_shared(dml_device, cmd_queue); } +void DmlConfigureProviderFactoryDefaultRoundingMode(IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode) { + auto dml_prvider_factory = static_cast(factory); + dml_prvider_factory->SetDefaultRoundingMode(rounding_mode); +} + std::shared_ptr CreateExecutionProviderFactory_DML(int device_id) { ComPtr dxgi_factory; THROW_IF_FAILED(CreateDXGIFactory2(0, IID_PPV_ARGS(&dxgi_factory))); @@ -77,7 +90,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(in // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled #if _DEBUG ComPtr debug_device; - (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure + (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr); if (is_d3d12_debug_layer_enabled) { @@ -91,7 +104,6 @@ std::shared_ptr CreateExecutionProviderFactory_DML(in DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dml_device))); - return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); } diff --git a/winml/adapter/winml_adapter_c_api.cpp b/winml/adapter/winml_adapter_c_api.cpp index c54af31513e82..2b827c1ce4a7f 100644 --- a/winml/adapter/winml_adapter_c_api.cpp +++ b/winml/adapter/winml_adapter_c_api.cpp @@ -80,12 +80,12 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = { &winmla::DmlExecutionProviderSetDefaultRoundingMode, &winmla::DmlExecutionProviderFlushContext, &winmla::DmlExecutionProviderTrimUploadHeap, - &winmla::DmlExecutionProviderReleaseCompletedReferences, - + &winmla::DmlExecutionProviderReleaseCompletedReferences, &winmla::DmlCreateGPUAllocationFromD3DResource, &winmla::DmlFreeGPUAllocation, &winmla::DmlGetD3D12ResourceFromAllocation, &winmla::DmlCopyTensor, + &winmla::GetProviderMemoryInfo, &winmla::GetProviderAllocator, &winmla::FreeProviderAllocator, diff --git a/winml/adapter/winml_adapter_dml.cpp b/winml/adapter/winml_adapter_dml.cpp index cf1e7b81445b6..f9a60a35970b4 100644 --- a/winml/adapter/winml_adapter_dml.cpp +++ b/winml/adapter/winml_adapter_dml.cpp @@ -8,6 +8,7 @@ #include "winml_adapter_apis.h" #include "core/framework/error_code_helper.h" +#include "core/session/abi_session_options_impl.h" #include "core/providers/dml/dml_provider_factory.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" @@ -42,10 +43,24 @@ Microsoft::WRL::ComPtr CreateDmlDevice(ID3D12Device* d3d12Device) { return dmlDevice; } +namespace onnxruntime { +void DmlConfigureProviderFactoryDefaultRoundingMode(onnxruntime::IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode); +} + ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, ID3D12Device* d3d_device, ID3D12CommandQueue* queue) { auto dml_device = CreateDmlDevice(d3d_device); - return OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue); + if (auto status = OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue)) { + return status; + } + auto factory = options->provider_factories.back().get(); + + // OnnxRuntime uses the default rounding mode when calling the session's allocator. + // During initialization, OnnxRuntime allocates weights, which are permanent across session + // lifetime and can be large, so shouldn't be rounded. + // So we create the provider with rounding disabled, and expect the caller to enable it after. + onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled); + return nullptr; } ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled) { diff --git a/winml/adapter/winml_adapter_session.cpp b/winml/adapter/winml_adapter_session.cpp index d806ba4491e5d..aedd82b677390 100644 --- a/winml/adapter/winml_adapter_session.cpp +++ b/winml/adapter/winml_adapter_session.cpp @@ -40,12 +40,52 @@ class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSessi ORT_API_STATUS_IMPL(winmla::CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session) { API_IMPL_BEGIN + std::unique_ptr inference_session; try { // Create the inference session - *session = reinterpret_cast(new onnxruntime::InferenceSession(options->value, env->GetLoggingManager())); + inference_session = std::make_unique(options->value, env->GetLoggingManager()); } catch (const std::exception& e) { return OrtApis::CreateStatus(ORT_FAIL, e.what()); } + + // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of + // byte addressable memory + std::vector> provider_list; + if (options) { + for (auto& factory : options->provider_factories) { + auto provider = factory->CreateProvider(); + if (provider->Type() == onnxruntime::kDmlExecutionProvider) { + if (options->value.enable_mem_pattern) { + // TODO Instead of returning an error, should we set mem pattern to false here and log a warning saying so? + // Doing so would be inconsistent with the Python API that doesn't go through this code path. + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Mem pattern should be disabled when using DML execution provider."); + } + if (options->value.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Sequential execution should be enabled when using DML execution provider."); + } + } + provider_list.push_back(std::move(provider)); + } + } + + Status status; + if (options) { + if (!options->custom_op_domains_.empty()) { + status = inference_session->AddCustomOpDomains(options->custom_op_domains_); + if (!status.IsOK()) + return onnxruntime::ToOrtStatus(status); + } + } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + inference_session->RegisterExecutionProvider(std::move(provider)); + } + } + + *session = reinterpret_cast(inference_session.release()); + return nullptr; API_IMPL_END } diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp index 1362929e7f50c..425c7de5bbf3c 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp @@ -40,8 +40,12 @@ OnnxruntimeDmlSessionBuilder::CreateSessionOptions( // Request the dml ep winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get()); - // Request the cpu ep as well.... todo check if we need this - // winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), true); +#ifndef _WIN64 + auto use_arena = false; +#else + auto use_arena = true; +#endif + winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena); // call release() so the underlying OrtSessionOptions object isn't freed *options = session_options.release(); @@ -73,23 +77,18 @@ HRESULT OnnxruntimeDmlSessionBuilder::Initialize( OrtSession* session) { RETURN_HR_IF_NULL(E_INVALIDARG, session); auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + if (auto status = winml_adapter_api->SessionInitialize(session)) { + return E_FAIL; + } + OrtExecutionProvider* ort_provider; + winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider); + size_t num_providers; winml_adapter_api->SessionGetExecutionProvidersCount(session, &num_providers); RETURN_HR_IF(E_UNEXPECTED, num_providers != 2); - OrtExecutionProvider* ort_provider; - winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider); - - // OnnxRuntime uses the default rounding mode when calling the session's allocator. - // During initialization, OnnxRuntime allocates weights, which are permanent across session - // lifetime and can be large, so shouldn't be rounded. - winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, false); - - if (auto status = winml_adapter_api->SessionInitialize(session)) { - return E_FAIL; - } - winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true); // Flush the D3D12 work from the DML execution provider diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp index b27b1390306a3..bac7fcaa60d94 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -324,8 +324,24 @@ HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() { return S_OK; } -HRESULT OnnxruntimeEngine::CopyOneInputAcrossDevices(const char* input_name, const IValue* src, IValue** dest) { - return E_NOTIMPL; +HRESULT OnnxruntimeEngine::CopyValueAcrossDevices(IValue* src, IValue* dest) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider); + + auto src_value = static_cast(src); + auto dest_value = static_cast(dest); + + bool is_empty; + auto has_null_source = (SUCCEEDED(src_value->IsEmpty(&is_empty)) && is_empty); + RETURN_HR_IF(E_FAIL, has_null_source); + + auto has_null_dest = (SUCCEEDED(dest_value->IsEmpty(&is_empty)) && is_empty); + RETURN_HR_IF(E_FAIL, has_null_dest); + + winml_adapter_api->DmlCopyTensor(ort_provider, src_value->UseOrtValue(), dest_value->UseOrtValue()); + return S_OK; } HRESULT OnnxruntimeEngine::Sync() { @@ -450,7 +466,7 @@ HRESULT OnnxruntimeEngine::CreateNullValue(_Out_ IValue** out) { return S_OK; } -HRESULT OnnxruntimeEngine::CopyOneInputAcrossDevices(const char* name, IValue* src, IValue** out) { +HRESULT OnnxruntimeEngine::CreateOneInputAcrossDevices(const char* name, IValue* src, IValue** out) { auto ort_api = engine_factory_->UseOrtApi(); auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h index 0e9f7e3d0fc4e..ef5ca87cefd11 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -71,13 +71,13 @@ class OnnxruntimeEngine : public Microsoft::WRL::RuntimeClass< STDMETHOD(FlushContext)() override; STDMETHOD(TrimUploadHeap)() override; STDMETHOD(ReleaseCompletedReferences)() override; - STDMETHOD(CopyOneInputAcrossDevices)(const char* input_name, const IValue* src, IValue** dest) override; STDMETHOD(Sync)() override; STDMETHOD(CreateTensorValue)(int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; STDMETHOD(CreateTensorValueFromExternalD3DResource)(ID3D12Resource* resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; STDMETHOD(CreateTensorValueFromExternalBuffer)(void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; STDMETHOD(CreateNullValue)(_Out_ IValue** out) override; - STDMETHOD(CopyOneInputAcrossDevices)(const char* name, IValue* src, IValue** out) override; + STDMETHOD(CreateOneInputAcrossDevices)(const char* name, IValue* src, IValue** dest) override; + STDMETHOD(CopyValueAcrossDevices)(IValue* src, IValue* dest) override; STDMETHOD(Run)(const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) override; OrtSession* UseOrtSession(); diff --git a/winml/lib/Api/LearningModelBinding.cpp b/winml/lib/Api/LearningModelBinding.cpp index f031b707fc0b6..8d662fc536148 100644 --- a/winml/lib/Api/LearningModelBinding.cpp +++ b/winml/lib/Api/LearningModelBinding.cpp @@ -470,7 +470,7 @@ HRESULT LearningModelBinding::BindInput(const std::string& name, winrt::com_ptr< auto engine = m_session.as()->GetEngine(); winrt::com_ptr device_value; - WINML_THROW_IF_FAILED(engine->CopyOneInputAcrossDevices(name.c_str(), value.get(), device_value.put())); // an input will always be copied on device mismatch + WINML_THROW_IF_FAILED(engine->CreateOneInputAcrossDevices(name.c_str(), value.get(), device_value.put())); // an input will always be copied on device mismatch if (exists) { inputs_[index] = device_value; diff --git a/winml/lib/Api/impl/TensorBase.h b/winml/lib/Api/impl/TensorBase.h index 3f040c32ab90f..d4691f7000317 100644 --- a/winml/lib/Api/impl/TensorBase.h +++ b/winml/lib/Api/impl/TensorBase.h @@ -281,7 +281,7 @@ struct TensorBase : TBase { GetCpuResource()->buffer().second, GetCpuResource()->size_in_bytes(), GetCpuResource()->shape().data(), GetCpuResource()->shape().size(), TensorKind(), dest.put()), "Failed to prepare buffer for copy back from device resource."); - //RETURN_IF_FAILED(engine->CopyTensor(value, dest.get())); + RETURN_IF_FAILED(engine->CopyValueAcrossDevices(value, dest.get())); } return S_OK; diff --git a/winml/lib/Common/inc/iengine.h b/winml/lib/Common/inc/iengine.h index 0f22015151503..af755965fc26e 100644 --- a/winml/lib/Common/inc/iengine.h +++ b/winml/lib/Common/inc/iengine.h @@ -46,13 +46,13 @@ MIDL_INTERFACE("30c99886-38d2-41cb-a615-203fe7d7daac") IEngine : IUnknown { STDMETHOD(FlushContext)() PURE; STDMETHOD(TrimUploadHeap)() PURE; STDMETHOD(ReleaseCompletedReferences)() PURE; - STDMETHOD(CopyOneInputAcrossDevices)(const char* input_name, const IValue* source, IValue** dest) PURE; STDMETHOD(Sync)() PURE; STDMETHOD(CreateTensorValue)(int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE; STDMETHOD(CreateTensorValueFromExternalD3DResource)(ID3D12Resource* resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE; STDMETHOD(CreateTensorValueFromExternalBuffer)(void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE; STDMETHOD(CreateNullValue)(_Out_ IValue** out) PURE; - STDMETHOD(CopyOneInputAcrossDevices)(const char* name, IValue* src, IValue** out) PURE; + STDMETHOD(CreateOneInputAcrossDevices)(const char* name, IValue* src, IValue** dest) PURE; + STDMETHOD(CopyValueAcrossDevices)(IValue* src, IValue* dest) PURE; STDMETHOD(Run)(const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) PURE; }; From 18412d31cf128b357396cd12bb038224a81124b8 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Fri, 17 Jan 2020 15:57:35 -0800 Subject: [PATCH 2/2] wrapper all existing winml adapter apis with API_IMPL to try catch (#2854) --- .../core/framework/onnxruntime_typeinfo.cc | 7 ++++ winml/adapter/winml_adapter_dml.cpp | 2 ++ winml/adapter/winml_adapter_environment.cpp | 6 +++- winml/adapter/winml_adapter_map_type_info.cpp | 4 +++ winml/adapter/winml_adapter_model.cpp | 36 +++++++++++++++++++ .../winml_adapter_sequence_type_info.cpp | 2 ++ 6 files changed, 56 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 06c8d8d232d24..10084c0855261 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -10,6 +10,7 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" #include "core/framework/tensor_type_and_shape.h" #include "../../winml/adapter/winml_adapter_map_type_info.h" @@ -61,19 +62,25 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtType } ORT_API_STATUS_IMPL(winmla::CastTypeInfoToMapTypeInfo, const OrtTypeInfo* type_info, const OrtMapTypeInfo** out) { + API_IMPL_BEGIN *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info_ : nullptr; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::CastTypeInfoToSequenceTypeInfo, const OrtTypeInfo* type_info, const OrtSequenceTypeInfo** out) { + API_IMPL_BEGIN *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info_ : nullptr; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::GetDenotationFromTypeInfo, const OrtTypeInfo* type_info, const char** const out, size_t* len) { + API_IMPL_BEGIN *out = type_info->denotation_.c_str(); *len = type_info->denotation_.size(); return nullptr; + API_IMPL_END } ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { diff --git a/winml/adapter/winml_adapter_dml.cpp b/winml/adapter/winml_adapter_dml.cpp index f9a60a35970b4..faac6c6535df2 100644 --- a/winml/adapter/winml_adapter_dml.cpp +++ b/winml/adapter/winml_adapter_dml.cpp @@ -49,6 +49,7 @@ void DmlConfigureProviderFactoryDefaultRoundingMode(onnxruntime::IExecutionProvi ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, ID3D12Device* d3d_device, ID3D12CommandQueue* queue) { + API_IMPL_BEGIN auto dml_device = CreateDmlDevice(d3d_device); if (auto status = OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue)) { return status; @@ -61,6 +62,7 @@ ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ // So we create the provider with rounding disabled, and expect the caller to enable it after. onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled) { diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp index 34d857391f9be..ca041a485a843 100644 --- a/winml/adapter/winml_adapter_environment.cpp +++ b/winml/adapter/winml_adapter_environment.cpp @@ -49,6 +49,7 @@ class WinmlAdapterLoggingWrapper : public LoggingWrapper { ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out) { + API_IMPL_BEGIN std::string name = logid; std::unique_ptr logger = onnxruntime::make_unique(logging_function, profiling_function, logger_param); @@ -64,6 +65,7 @@ ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* en // Set a new default logging manager env->SetLoggingManager(std::move(winml_logging_manager)); return nullptr; + API_IMPL_END } // Override select shape inference functions which are incomplete in ONNX with versions that are complete, @@ -72,11 +74,13 @@ ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* en // registered schema are reachable only after upstream schema have been revised in a later OS release, // which would be a compatibility risk. ORT_API_STATUS_IMPL(winmla::OverrideSchema) { + API_IMPL_BEGIN #ifdef USE_DML static std::once_flag schema_override_once_flag; std::call_once(schema_override_once_flag, []() { SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); }); +#endif USE_DML. return nullptr; -#endif USE_DML + API_IMPL_END } \ No newline at end of file diff --git a/winml/adapter/winml_adapter_map_type_info.cpp b/winml/adapter/winml_adapter_map_type_info.cpp index 0590b408d23c8..622d93acfa817 100644 --- a/winml/adapter/winml_adapter_map_type_info.cpp +++ b/winml/adapter/winml_adapter_map_type_info.cpp @@ -73,12 +73,16 @@ OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) { // OrtMapTypeInfo Accessors ORT_API_STATUS_IMPL(winmla::GetMapKeyType, const OrtMapTypeInfo* map_type_info, enum ONNXTensorElementDataType* out) { + API_IMPL_BEGIN *out = map_type_info->map_key_type_; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::GetMapValueType, const OrtMapTypeInfo* map_type_info, OrtTypeInfo** out) { + API_IMPL_BEGIN return map_type_info->map_value_type_->Clone(out); + API_IMPL_END } ORT_API(void, winmla::ReleaseMapTypeInfo, OrtMapTypeInfo* ptr) { diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp index 097d0d30888d9..d68443d689088 100644 --- a/winml/adapter/winml_adapter_model.cpp +++ b/winml/adapter/winml_adapter_model.cpp @@ -195,90 +195,118 @@ std::unique_ptr OrtModel::DetachModelProto() { } ORT_API_STATUS_IMPL(winmla::CreateModelFromPath, const char* model_path, size_t size, OrtModel** out) { + API_IMPL_BEGIN if (auto status = OrtModel::CreateOrtModelFromPath(model_path, size, out)) { return status; } return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::CreateModelFromData, void* data, size_t size, OrtModel** out) { + API_IMPL_BEGIN if (auto status = OrtModel::CreateOrtModelFromData(data, size, out)) { return status; } return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::CloneModel, const OrtModel* in, OrtModel** out) { + API_IMPL_BEGIN auto model_proto_copy = std::make_unique(*in->UseModelProto()); if (auto status = OrtModel::CreateOrtModelFromProto(std::move(model_proto_copy), out)) { return status; } return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetAuthor, const OrtModel* model, const char** const author, size_t* len) { + API_IMPL_BEGIN *author = model->UseModelInfo()->author_.c_str(); *len = model->UseModelInfo()->author_.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetName, const OrtModel* model, const char** const name, size_t* len) { + API_IMPL_BEGIN *name = model->UseModelInfo()->name_.c_str(); *len = model->UseModelInfo()->name_.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetDomain, const OrtModel* model, const char** const domain, size_t* len) { + API_IMPL_BEGIN *domain = model->UseModelInfo()->domain_.c_str(); *len = model->UseModelInfo()->domain_.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetDescription, const OrtModel* model, const char** const description, size_t* len) { + API_IMPL_BEGIN *description = model->UseModelInfo()->description_.c_str(); *len = model->UseModelInfo()->description_.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetVersion, const OrtModel* model, int64_t* version) { + API_IMPL_BEGIN *version = model->UseModelInfo()->version_; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetMetadataCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN *count = model->UseModelInfo()->model_metadata_.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetMetadata, const OrtModel* model, size_t count, const char** const key, size_t* key_len, const char** const value, size_t* value_len) { + API_IMPL_BEGIN *key = model->UseModelInfo()->model_metadata_[count].first.c_str(); *key_len = model->UseModelInfo()->model_metadata_[count].first.size(); *value = model->UseModelInfo()->model_metadata_[count].second.c_str(); *value_len = model->UseModelInfo()->model_metadata_[count].second.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetInputCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN *count = model->UseModelInfo()->input_features_.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetOutputCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN *count = model->UseModelInfo()->output_features_.size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetInputName, const OrtModel* model, size_t index, const char** input_name, size_t* count) { + API_IMPL_BEGIN *input_name = model->UseModelInfo()->input_features_[index]->name().c_str(); *count = model->UseModelInfo()->input_features_[index]->name().size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t index, const char** output_name, size_t* count) { + API_IMPL_BEGIN *output_name = model->UseModelInfo()->output_features_[index]->name().c_str(); *count = model->UseModelInfo()->output_features_[index]->name().size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, const char** input_description, size_t* count) { @@ -288,26 +316,33 @@ ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, siz } ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, const char** output_description, size_t* count) { + API_IMPL_BEGIN *output_description = model->UseModelInfo()->output_features_[index]->doc_string().c_str(); *count = model->UseModelInfo()->output_features_[index]->doc_string().size(); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetInputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) { + API_IMPL_BEGIN if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->input_features_[index]->type(), type_info)) { return status; } return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetOutputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) { + API_IMPL_BEGIN if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->output_features_[index]->type(), type_info)) { return status; } return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) { + API_IMPL_BEGIN auto model_info = model->UseModelInfo(); auto model_proto = model->UseModelProto(); auto& graph = model_proto->graph(); @@ -372,6 +407,7 @@ ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) { } } return nullptr; + API_IMPL_END } ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) { diff --git a/winml/adapter/winml_adapter_sequence_type_info.cpp b/winml/adapter/winml_adapter_sequence_type_info.cpp index 86f6789d5d50d..35bf03c6c0a64 100644 --- a/winml/adapter/winml_adapter_sequence_type_info.cpp +++ b/winml/adapter/winml_adapter_sequence_type_info.cpp @@ -44,7 +44,9 @@ OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) { } ORT_API_STATUS_IMPL(winmla::GetSequenceElementType, const OrtSequenceTypeInfo* sequence_type_info, OrtTypeInfo** out) { + API_IMPL_BEGIN return sequence_type_info->sequence_key_type_->Clone(out); + API_IMPL_END } ORT_API(void, winmla::ReleaseSequenceTypeInfo, OrtSequenceTypeInfo* ptr) {