Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User/xianz/winml adapter c api #2869

Merged
1 change: 1 addition & 0 deletions cmake/winml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ add_library(winml_lib_ort STATIC
${winml_lib_api_ort_dir}/OnnxruntimeModel.h
${winml_lib_api_ort_dir}/OnnxruntimeModel.cpp
${winml_lib_api_ort_dir}/OnnxruntimeSessionBuilder.h
${winml_lib_api_ort_dir}/OnnxruntimeErrors.h
${winml_lib_api_ort_dir}/pch.h
)

Expand Down
2 changes: 2 additions & 0 deletions winml/adapter/winml_adapter_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,11 @@ ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t in
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, const char** input_description, size_t* count) {
API_IMPL_BEGIN
*input_description = model->UseModelInfo()->input_features_[index]->doc_string().c_str();
*count = model->UseModelInfo()->input_features_[index]->doc_string().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, const char** output_description, size_t* count) {
Expand Down
22 changes: 13 additions & 9 deletions winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "OnnxruntimeCpuSessionBuilder.h"
#include "OnnxruntimeEngine.h"
#include "OnnxruntimeErrors.h"

using namespace Windows::AI::MachineLearning;

Expand Down Expand Up @@ -40,15 +41,15 @@ OnnxruntimeCpuSessionBuilder::CreateSessionOptions(
#else
auto use_arena = true;
#endif
winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena),
ort_api);

// call release() so the underlying OrtSessionOptions object isn't freed
*options = session_options.release();

return S_OK;
}


HRESULT
OnnxruntimeCpuSessionBuilder::CreateSession(
OrtSessionOptions* options,
Expand All @@ -62,11 +63,13 @@ OnnxruntimeCpuSessionBuilder::CreateSession(
RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env));

OrtSession* ort_session_raw;
winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw),
engine_factory_->UseOrtApi());

auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession);

*session = ort_session.release();

return S_OK;
}

Expand All @@ -76,12 +79,13 @@ OnnxruntimeCpuSessionBuilder::Initialize(
RETURN_HR_IF_NULL(E_INVALIDARG, session);

auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
if (auto status = winml_adapter_api->SessionInitialize(session)) {
return E_FAIL;
}
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionInitialize(session),
engine_factory_->UseOrtApi());

size_t num_providers;
winml_adapter_api->SessionGetExecutionProvidersCount(session, &num_providers);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionGetExecutionProvidersCount(session, &num_providers),
engine_factory_->UseOrtApi());

RETURN_HR_IF(E_UNEXPECTED, num_providers != 1);
return S_OK;
}
70 changes: 36 additions & 34 deletions winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "OnnxruntimeEngine.h"

#include "OnnxruntimeErrors.h"

using namespace winrt::Windows::AI::MachineLearning;

// BitmapPixelFormat constants
Expand Down Expand Up @@ -44,12 +46,11 @@ static const char* c_supported_nominal_ranges[] =

namespace Windows::AI::MachineLearning {


// Forward declare CreateFeatureDescriptor
static winml::ILearningModelFeatureDescriptor
CreateFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata);

static TensorKind
Expand Down Expand Up @@ -100,7 +101,9 @@ TensorKindFromONNXTensorElementDataType(ONNXTensorElementDataType dataType) {
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: {
return TensorKind::Complex128;
}
default: { return TensorKind::Undefined; }
default: {
return TensorKind::Undefined;
}
}
}

Expand Down Expand Up @@ -153,7 +156,9 @@ TensorKindToString(TensorKind tensorKind) {
return "complex128";
}
case TensorKind::Undefined:
default: { return "undefined"; }
default: {
Copy link
Member Author

@zhangxiang1993 zhangxiang1993 Jan 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did clang format for these files. #Resolved

return "undefined";
}
}
}

Expand Down Expand Up @@ -310,9 +315,8 @@ GetTensorType(
const std::unordered_map<std::string, std::string>& metadata) {
const char* denotation;
size_t len;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len)) {
throw; //TODO fix throw here!;
}
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len),
engine_factory->UseOrtApi());

auto has_image_denotation = strncmp(denotation, "IMAGE", len) != 0;
if (!has_image_denotation) {
Expand Down Expand Up @@ -395,7 +399,7 @@ GetTensorType(
static winml::ILearningModelFeatureDescriptor
CreateTensorFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata,
bool has_unsupported_image_metadata) {
auto type_info = feature_descriptor->type_info_.get();
Expand All @@ -412,7 +416,7 @@ CreateTensorFeatureDescriptor(

auto shape = std::vector<int64_t>(num_dims);
if (auto status = engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size())) {
throw; //TODO fix throw here!;
throw; //TODO fix throw here!;
Copy link
Contributor

@tiagoshibata tiagoshibata Jan 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO #Resolved

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same above


In reply to: 368191885 [](ancestors = 368191885)

}

ONNXTensorElementDataType tensor_element_data_type;
Expand All @@ -424,7 +428,7 @@ CreateTensorFeatureDescriptor(
auto descriptor = winrt::make<winmlp::TensorFeatureDescriptor>(
WinML::Strings::HStringFromUTF8(feature_descriptor->name_),
WinML::Strings::HStringFromUTF8(feature_descriptor->description_), // description
feature_descriptor->name_length_ > 0, // is_required
feature_descriptor->name_length_ > 0, // is_required
kind,
shape,
has_unsupported_image_metadata);
Expand All @@ -435,7 +439,7 @@ CreateTensorFeatureDescriptor(
static winml::ILearningModelFeatureDescriptor
CreateImageFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata) {
auto type_info = feature_descriptor->type_info_.get();

Expand All @@ -460,7 +464,6 @@ CreateImageFeatureDescriptor(
}
auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type);


// pixel format and alpha
auto pixel_format_value = FetchMetadataValueOrNull(metadata, c_bitmap_pixel_format_key);
auto format_info = CreateBitmapPixelFormatAndAlphaModeInfo(pixel_format_value);
Expand All @@ -472,12 +475,12 @@ CreateImageFeatureDescriptor(
// to TensorFeatureDescriptor (invalid image metadata)
#ifdef DONE_LAYERING
// color space gamma value
auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key);
auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value);
auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key);
auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value);

// nominal range
auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key);
auto nominal_range = CreateImageNominalPixelRange(nominal_range_value);
auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key);
auto nominal_range = CreateImageNominalPixelRange(nominal_range_value);
#endif

// The current code assumes that the shape will be in NCHW.
Expand All @@ -503,25 +506,24 @@ CreateImageFeatureDescriptor(
static winml::ILearningModelFeatureDescriptor
CreateMapFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata) {
auto type_info = feature_descriptor->type_info_.get();

const OrtMapTypeInfo* map_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info),
engine_factory->UseOrtApi());

ONNXTensorElementDataType map_key_data_type;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type)) {
throw; //TODO fix throw here!;
}
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type),
engine_factory->UseOrtApi());

auto key_kind = WinML::TensorKindFromONNXTensorElementDataType(map_key_data_type);

OrtTypeInfo* map_value_type_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info),
engine_factory->UseOrtApi());

UniqueOrtTypeInfo unique_map_value_type_info(map_value_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo);

OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper;
Expand All @@ -534,7 +536,6 @@ CreateMapFeatureDescriptor(
auto value_descriptor =
CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata);


auto descriptor = winrt::make<winmlp::MapFeatureDescriptor>(
WinML::Strings::HStringFromUTF8(feature_descriptor->name_),
WinML::Strings::HStringFromUTF8(feature_descriptor->description_),
Expand All @@ -547,19 +548,18 @@ CreateMapFeatureDescriptor(
static winml::ILearningModelFeatureDescriptor
CreateSequenceFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata) {
auto type_info = feature_descriptor->type_info_.get();

const OrtSequenceTypeInfo* sequence_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info),
engine_factory->UseOrtApi());

OrtTypeInfo* sequence_element_type_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info),
engine_factory->UseOrtApi());

UniqueOrtTypeInfo unique_sequence_element_type_info(sequence_element_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo);

OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper;
Expand Down Expand Up @@ -590,6 +590,8 @@ CreateFeatureDescriptor(

ONNXType onnx_type;
engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type);
engine_factory->UseOrtApi();
Copy link
Member Author

@zhangxiang1993 zhangxiang1993 Jan 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

miss-added, will be deleted #Resolved


switch (onnx_type) {
case ONNXType::ONNX_TYPE_TENSOR: {
auto tensor_type = GetTensorType(engine_factory, type_info, metadata);
Expand Down
30 changes: 19 additions & 11 deletions winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "OnnxruntimeDmlSessionBuilder.h"
#include "OnnxruntimeEngine.h"
#include "OnnxruntimeErrors.h"
#include "LearningModelDevice.h"

using namespace Windows::AI::MachineLearning;
Expand Down Expand Up @@ -38,15 +39,17 @@ OnnxruntimeDmlSessionBuilder::CreateSessionOptions(
ort_api->DisableMemPattern(session_options.get());

// Request the dml ep
winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get());
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get()),
ort_api);

#ifndef _WIN64
auto use_arena = false;
#else
auto use_arena = true;
#endif
winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena);

RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena),
ort_api);

// call release() so the underlying OrtSessionOptions object isn't freed
*options = session_options.release();

Expand All @@ -65,7 +68,8 @@ HRESULT OnnxruntimeDmlSessionBuilder::CreateSession(
RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env));

OrtSession* ort_session_raw;
winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw),
engine_factory_->UseOrtApi());
auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession);

*session = ort_session.release();
Expand All @@ -78,21 +82,25 @@ HRESULT OnnxruntimeDmlSessionBuilder::Initialize(
RETURN_HR_IF_NULL(E_INVALIDARG, session);
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

if (auto status = winml_adapter_api->SessionInitialize(session)) {
return E_FAIL;
}
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionInitialize(session),
engine_factory_->UseOrtApi());

OrtExecutionProvider* ort_provider;
winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider),
engine_factory_->UseOrtApi());


size_t num_providers;
winml_adapter_api->SessionGetExecutionProvidersCount(session, &num_providers);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->SessionGetExecutionProvidersCount(session, &num_providers),
engine_factory_->UseOrtApi());
RETURN_HR_IF(E_UNEXPECTED, num_providers != 2);

winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true),
engine_factory_->UseOrtApi());

// Flush the D3D12 work from the DML execution provider
winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider);
RETURN_HR_IF_WINMLA_API_FAIL_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider),
engine_factory_->UseOrtApi());

return S_OK;
}
Expand Down
Loading