Skip to content

Commit

Permalink
merge branch user/sheil/winmladapter_c_api
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxiang1993 committed Jan 18, 2020
2 parents 044c861 + 18412d3 commit e519d73
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 28 deletions.
18 changes: 15 additions & 3 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,22 @@ struct DMLProviderFactory : IExecutionProviderFactory {
~DMLProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
void SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode);

private:
ComPtr<IDMLDevice> dml_device_{};
ComPtr<ID3D12CommandQueue> cmd_queue_{};
AllocatorRoundingMode rounding_mode_ = AllocatorRoundingMode::Enabled;
};

// Creates a DML execution provider from this factory's DML device and D3D12
// command queue, applies the factory's configured default allocator rounding
// mode to it, and returns ownership of the provider to the caller.
std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
  auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get());
  // The rounding mode has to be applied before the provider is returned;
  // returning the freshly created provider directly would drop the value
  // stored in rounding_mode_.
  Dml::SetDefaultRoundingMode(provider.get(), rounding_mode_);
  return provider;
}

// Stores the allocator rounding mode in this factory. The stored value is the
// one CreateProvider() hands to Dml::SetDefaultRoundingMode for new providers.
void DMLProviderFactory::SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode) {
  this->rounding_mode_ = rounding_mode;
}

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(IDMLDevice* dml_device,
Expand All @@ -55,6 +63,11 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(ID
return std::make_shared<onnxruntime::DMLProviderFactory>(dml_device, cmd_queue);
}

// Sets the default allocator rounding mode on a DML provider factory.
//
// `factory` is downcast with static_cast and no runtime type check is made,
// so the caller must pass a pointer to an actual DMLProviderFactory.
void DmlConfigureProviderFactoryDefaultRoundingMode(IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode) {
  static_cast<DMLProviderFactory*>(factory)->SetDefaultRoundingMode(rounding_mode);
}

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id) {
ComPtr<IDXGIFactory4> dxgi_factory;
THROW_IF_FAILED(CreateDXGIFactory2(0, IID_PPV_ARGS(&dxgi_factory)));
Expand All @@ -77,7 +90,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in
// In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled
#if _DEBUG
ComPtr<ID3D12DebugDevice> debug_device;
(void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure
(void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure
const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr);

if (is_d3d12_debug_layer_enabled) {
Expand All @@ -91,7 +104,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in
DML_FEATURE_LEVEL_2_0,
IID_PPV_ARGS(&dml_device)));


return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
}

Expand Down
4 changes: 2 additions & 2 deletions winml/adapter/winml_adapter_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
&winmla::DmlExecutionProviderSetDefaultRoundingMode,
&winmla::DmlExecutionProviderFlushContext,
&winmla::DmlExecutionProviderTrimUploadHeap,
&winmla::DmlExecutionProviderReleaseCompletedReferences,

&winmla::DmlExecutionProviderReleaseCompletedReferences,
&winmla::DmlCreateGPUAllocationFromD3DResource,
&winmla::DmlFreeGPUAllocation,
&winmla::DmlGetD3D12ResourceFromAllocation,
&winmla::DmlCopyTensor,

&winmla::GetProviderMemoryInfo,
&winmla::GetProviderAllocator,
&winmla::FreeProviderAllocator,
Expand Down
17 changes: 16 additions & 1 deletion winml/adapter/winml_adapter_dml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "winml_adapter_apis.h"
#include "core/framework/error_code_helper.h"

#include "core/session/abi_session_options_impl.h"
#include "core/providers/dml/dml_provider_factory.h"
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"

Expand Down Expand Up @@ -42,11 +43,25 @@ Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {
return dmlDevice;
}

namespace onnxruntime {
void DmlConfigureProviderFactoryDefaultRoundingMode(onnxruntime::IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode);
}

// Appends the DML execution provider to `options` using a DML device created
// from `d3d_device` and the supplied command queue, then configures the newly
// appended provider factory so that providers it creates start with allocator
// rounding disabled.
//
// Returns the failing OrtStatus from the append call, or nullptr on success.
ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
                    ID3D12Device* d3d_device, ID3D12CommandQueue* queue) {
  API_IMPL_BEGIN
  auto dml_device = CreateDmlDevice(d3d_device);
  // A non-null status signals failure; it must be checked (not returned
  // unconditionally) so the rounding-mode configuration below actually runs.
  if (auto status = OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue)) {
    return status;
  }
  // Assumes the successful append pushed the DML factory to the back of the list.
  auto factory = options->provider_factories.back().get();

  // OnnxRuntime uses the default rounding mode when calling the session's allocator.
  // During initialization, OnnxRuntime allocates weights, which are permanent across session
  // lifetime and can be large, so shouldn't be rounded.
  // So we create the provider with rounding disabled, and expect the caller to enable it after.
  onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled);
  return nullptr;
  API_IMPL_END
}

Expand Down
42 changes: 41 additions & 1 deletion winml/adapter/winml_adapter_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,52 @@ class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSessi

// Creates an InferenceSession that has no model loaded yet, registers custom
// op domains and the execution providers described by `options`, and hands
// ownership of the session to the caller through `*session`.
//
// `*session` is only written on success, so no session is leaked or exposed
// half-configured when validation fails.
//
// Returns nullptr on success, otherwise an OrtStatus:
//   - ORT_INVALID_ARGUMENT if `options` is null, or if the DML provider is
//     requested with mem pattern enabled or non-sequential execution;
//   - ORT_FAIL if constructing the session throws;
//   - the converted status if adding custom op domains or registering an
//     execution provider fails.
ORT_API_STATUS_IMPL(winmla::CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session) {
  API_IMPL_BEGIN
  // `options` is dereferenced to construct the session, so reject null before
  // touching it instead of checking after the fact.
  if (options == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Session options must not be null.");
  }

  std::unique_ptr<onnxruntime::InferenceSession> inference_session;
  try {
    // Create the inference session. It stays owned by the unique_ptr until
    // every step below has succeeded.
    inference_session = std::make_unique<onnxruntime::InferenceSession>(options->value, env->GetLoggingManager());
  } catch (const std::exception& e) {
    return OrtApis::CreateStatus(ORT_FAIL, e.what());
  }

  // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of
  // byte addressable memory
  std::vector<std::unique_ptr<onnxruntime::IExecutionProvider>> provider_list;
  for (auto& factory : options->provider_factories) {
    auto provider = factory->CreateProvider();
    if (!provider) {
      // Nothing to validate or register for a factory that produced no provider.
      continue;
    }
    if (provider->Type() == onnxruntime::kDmlExecutionProvider) {
      if (options->value.enable_mem_pattern) {
        // TODO Instead of returning an error, should we set mem pattern to false here and log a warning saying so?
        // Doing so would be inconsistent with the Python API that doesn't go through this code path.
        return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Mem pattern should be disabled when using DML execution provider.");
      }
      if (options->value.execution_mode != ExecutionMode::ORT_SEQUENTIAL) {
        return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Sequential execution should be enabled when using DML execution provider.");
      }
    }
    provider_list.push_back(std::move(provider));
  }

  if (!options->custom_op_domains_.empty()) {
    const auto status = inference_session->AddCustomOpDomains(options->custom_op_domains_);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }
  }

  // register the providers, surfacing any registration failure to the caller
  for (auto& provider : provider_list) {
    const auto status = inference_session->RegisterExecutionProvider(std::move(provider));
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }
  }

  // Everything succeeded: transfer ownership to the caller.
  *session = reinterpret_cast<OrtSession*>(inference_session.release());

  return nullptr;
  API_IMPL_END
}
Expand Down
25 changes: 13 additions & 12 deletions winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,13 @@ OnnxruntimeDmlSessionBuilder::CreateSessionOptions(
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get()),
ort_api);

// Request the cpu ep as well.... todo check if we need this
// winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), true);
#ifndef _WIN64
auto use_arena = false;
#else
auto use_arena = true;
#endif
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena),
ort_api);

// call release() so the underlying OrtSessionOptions object isn't freed
*options = session_options.release();
Expand Down Expand Up @@ -76,24 +81,20 @@ HRESULT OnnxruntimeDmlSessionBuilder::Initialize(
OrtSession* session) {
RETURN_HR_IF_NULL(E_INVALIDARG, session);
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

size_t num_providers;
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionGetExecutionProvidersCount(session, &num_providers),

RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionInitialize(session),
engine_factory_->UseOrtApi());
RETURN_HR_IF(E_UNEXPECTED, num_providers != 2);

OrtExecutionProvider* ort_provider;
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider),
engine_factory_->UseOrtApi());

// OnnxRuntime uses the default rounding mode when calling the session's allocator.
// During initialization, OnnxRuntime allocates weights, which are permanent across session
// lifetime and can be large, so shouldn't be rounded.
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, false),

size_t num_providers;
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionGetExecutionProvidersCount(session, &num_providers),
engine_factory_->UseOrtApi());
RETURN_HR_IF(E_UNEXPECTED, num_providers != 2);

RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionInitialize(session),
engine_factory_->UseOrtApi());
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true),
engine_factory_->UseOrtApi());

Expand Down
22 changes: 19 additions & 3 deletions winml/lib/Api.Ort/OnnxruntimeEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,24 @@ HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() {
return S_OK;
}

HRESULT OnnxruntimeEngine::CopyOneInputAcrossDevices(const char* input_name, const IValue* src, IValue** dest) {
return E_NOTIMPL;
// Copies the tensor held by `src` into the tensor held by `dest` through the
// adapter's DmlCopyTensor entry point, using the session's execution provider
// at index 0 (presumably the DML provider -- confirm against session setup).
//
// Returns:
//   - E_INVALIDARG if `src` or `dest` is null;
//   - the failure HRESULT if querying either value's emptiness fails;
//   - E_FAIL if either value is empty;
//   - a failure HRESULT if fetching the provider or the copy itself fails;
//   - S_OK once the copy call has succeeded.
HRESULT OnnxruntimeEngine::CopyValueAcrossDevices(IValue* src, IValue* dest) {
  RETURN_HR_IF_NULL(E_INVALIDARG, src);
  RETURN_HR_IF_NULL(E_INVALIDARG, dest);

  auto ort_api = engine_factory_->UseOrtApi();
  auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

  // Check the adapter status so a failed lookup never leaves ort_provider
  // unset when it is passed to DmlCopyTensor below.
  OrtExecutionProvider* ort_provider = nullptr;
  RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
                                   ort_api);

  auto src_value = static_cast<OnnxruntimeValue*>(src);
  auto dest_value = static_cast<OnnxruntimeValue*>(dest);

  // An IsEmpty failure is reported to the caller instead of being treated as
  // "not empty".
  bool is_empty = false;
  RETURN_IF_FAILED(src_value->IsEmpty(&is_empty));
  RETURN_HR_IF(E_FAIL, is_empty);

  RETURN_IF_FAILED(dest_value->IsEmpty(&is_empty));
  RETURN_HR_IF(E_FAIL, is_empty);

  // Propagate a failed copy; previously the status was dropped and S_OK returned.
  RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->DmlCopyTensor(ort_provider, src_value->UseOrtValue(), dest_value->UseOrtValue()),
                                   ort_api);
  return S_OK;
}

HRESULT OnnxruntimeEngine::Sync() {
Expand Down Expand Up @@ -479,7 +495,7 @@ HRESULT OnnxruntimeEngine::CreateNullValue(_Out_ IValue** out) {
return S_OK;
}

HRESULT OnnxruntimeEngine::CopyOneInputAcrossDevices(const char* name, IValue* src, IValue** out) {
HRESULT OnnxruntimeEngine::CreateOneInputAcrossDevices(const char* name, IValue* src, IValue** out) {
auto ort_api = engine_factory_->UseOrtApi();
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

Expand Down
4 changes: 2 additions & 2 deletions winml/lib/Api.Ort/OnnxruntimeEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ class OnnxruntimeEngine : public Microsoft::WRL::RuntimeClass<
STDMETHOD(FlushContext)() override;
STDMETHOD(TrimUploadHeap)() override;
STDMETHOD(ReleaseCompletedReferences)() override;
STDMETHOD(CopyOneInputAcrossDevices)(const char* input_name, const IValue* src, IValue** dest) override;
STDMETHOD(Sync)() override;
STDMETHOD(CreateTensorValue)(int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override;
STDMETHOD(CreateTensorValueFromExternalD3DResource)(ID3D12Resource* resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override;
STDMETHOD(CreateTensorValueFromExternalBuffer)(void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override;
STDMETHOD(CreateNullValue)(_Out_ IValue** out) override;
STDMETHOD(CopyOneInputAcrossDevices)(const char* name, IValue* src, IValue** out) override;
STDMETHOD(CreateOneInputAcrossDevices)(const char* name, IValue* src, IValue** dest) override;
STDMETHOD(CopyValueAcrossDevices)(IValue* src, IValue* dest) override;
STDMETHOD(Run)(const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) override;

OrtSession* UseOrtSession();
Expand Down
2 changes: 1 addition & 1 deletion winml/lib/Api/LearningModelBinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ HRESULT LearningModelBinding::BindInput(const std::string& name, winrt::com_ptr<

auto engine = m_session.as<LearningModelSession>()->GetEngine();
winrt::com_ptr<WinML::IValue> device_value;
WINML_THROW_IF_FAILED(engine->CopyOneInputAcrossDevices(name.c_str(), value.get(), device_value.put())); // an input will always be copied on device mismatch
WINML_THROW_IF_FAILED(engine->CreateOneInputAcrossDevices(name.c_str(), value.get(), device_value.put())); // an input will always be copied on device mismatch

if (exists) {
inputs_[index] = device_value;
Expand Down
2 changes: 1 addition & 1 deletion winml/lib/Api/impl/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ struct TensorBase : TBase {
GetCpuResource()->buffer().second, GetCpuResource()->size_in_bytes(), GetCpuResource()->shape().data(),
GetCpuResource()->shape().size(), TensorKind(), dest.put()),
"Failed to prepare buffer for copy back from device resource.");
//RETURN_IF_FAILED(engine->CopyTensor(value, dest.get()));
RETURN_IF_FAILED(engine->CopyValueAcrossDevices(value, dest.get()));
}

return S_OK;
Expand Down
4 changes: 2 additions & 2 deletions winml/lib/Common/inc/iengine.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ MIDL_INTERFACE("30c99886-38d2-41cb-a615-203fe7d7daac") IEngine : IUnknown {
STDMETHOD(FlushContext)() PURE;
STDMETHOD(TrimUploadHeap)() PURE;
STDMETHOD(ReleaseCompletedReferences)() PURE;
STDMETHOD(CopyOneInputAcrossDevices)(const char* input_name, const IValue* source, IValue** dest) PURE;
STDMETHOD(Sync)() PURE;
STDMETHOD(CreateTensorValue)(int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE;
STDMETHOD(CreateTensorValueFromExternalD3DResource)(ID3D12Resource* resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE;
STDMETHOD(CreateTensorValueFromExternalBuffer)(void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE;
STDMETHOD(CreateNullValue)(_Out_ IValue** out) PURE;
STDMETHOD(CopyOneInputAcrossDevices)(const char* name, IValue* src, IValue** out) PURE;
STDMETHOD(CreateOneInputAcrossDevices)(const char* name, IValue* src, IValue** dest) PURE;
STDMETHOD(CopyValueAcrossDevices)(IValue* src, IValue* dest) PURE;
STDMETHOD(Run)(const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) PURE;
};

Expand Down

0 comments on commit e519d73

Please sign in to comment.