User/xianz/winml adapter c api (#2869)
* Wrap all existing WinML adapter APIs with API_IMPL to try/catch

* Return HR or throw for WinML adapter APIs when they fail

* Undo the macro wrapper in two places

* Wrap error macros around ORT APIs, too.
zhangxiang1993 authored Jan 24, 2020
1 parent 8bbf921 commit c714ec2
Showing 9 changed files with 367 additions and 236 deletions.
2 changes: 2 additions & 0 deletions winml/adapter/winml_adapter_model.cpp
@@ -310,9 +310,11 @@ ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t in
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, const char** input_description, size_t* count) {
API_IMPL_BEGIN
*input_description = model->UseModelInfo()->input_features_[index]->doc_string().c_str();
*count = model->UseModelInfo()->input_features_[index]->doc_string().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, const char** output_description, size_t* count) {
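The hunk above wraps the body of ModelGetInputDescription in API_IMPL_BEGIN/API_IMPL_END so that C++ exceptions are caught before they can cross the C API boundary. The macro definitions are not part of this diff; as a rough sketch, assuming they follow the usual try/catch-to-OrtStatus pattern used by the ORT C API (names and include paths below are illustrative, not taken from the repo):

#include <exception>
#include "onnxruntime_c_api.h"  // OrtStatus, OrtApi, ORT_RUNTIME_EXCEPTION

// Hypothetical helper: convert the in-flight exception into an OrtStatus*.
// The real adapter presumably has an equivalent under a different name.
inline OrtStatus* StatusFromCurrentException() {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  try {
    throw;  // rethrow so the handlers below can inspect the exception
  } catch (const std::exception& e) {
    return api->CreateStatus(ORT_RUNTIME_EXCEPTION, e.what());
  } catch (...) {
    return api->CreateStatus(ORT_RUNTIME_EXCEPTION, "unknown exception");
  }
}

#define API_IMPL_BEGIN try {
#define API_IMPL_END \
  }                  \
  catch (...) { return StatusFromCurrentException(); }

With an expansion along these lines, a wrapped API such as ModelGetInputDescription returns nullptr on success and a populated OrtStatus* on failure instead of leaking an exception to the caller.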
27 changes: 16 additions & 11 deletions winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp
@@ -5,6 +5,7 @@

#include "OnnxruntimeCpuSessionBuilder.h"
#include "OnnxruntimeEngine.h"
#include "OnnxruntimeErrors.h"

using namespace Windows::AI::MachineLearning;

@@ -22,33 +23,36 @@ OnnxruntimeCpuSessionBuilder::CreateSessionOptions(
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

OrtSessionOptions* ort_options;
ort_api->CreateSessionOptions(&ort_options);
RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&ort_options),
ort_api);

auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions);

// set the graph optimization level to all (used to be called level 3)
ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL);
RETURN_HR_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL),
ort_api);

// Onnxruntime will use half the number of concurrent threads supported on the system
// by default. This causes MLAS to not exercise every logical core.
// We force the thread pool size to be maxxed out to ensure that WinML always
// runs the fastest.
ort_api->SetIntraOpNumThreads(session_options.get(), std::thread::hardware_concurrency());
RETURN_HR_IF_NOT_OK_MSG(ort_api->SetIntraOpNumThreads(session_options.get(), std::thread::hardware_concurrency()),
ort_api);

#ifndef _WIN64
auto use_arena = false;
#else
auto use_arena = true;
#endif
winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena);
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena),
ort_api);

// call release() so the underlying OrtSessionOptions object isn't freed
*options = session_options.release();

return S_OK;
}


HRESULT
OnnxruntimeCpuSessionBuilder::CreateSession(
OrtSessionOptions* options,
Expand All @@ -62,11 +66,13 @@ OnnxruntimeCpuSessionBuilder::CreateSession(
RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env));

OrtSession* ort_session_raw;
winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw);
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw),
engine_factory_->UseOrtApi());

auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession);

*session = ort_session.release();

return S_OK;
}

@@ -76,9 +82,8 @@ OnnxruntimeCpuSessionBuilder::Initialize(
RETURN_HR_IF_NULL(E_INVALIDARG, session);

auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
if (auto status = winml_adapter_api->SessionInitialize(session)) {
return E_FAIL;
}
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session),
engine_factory_->UseOrtApi());

return S_OK;
}
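
Throughout this file the previously unchecked OrtApi and adapter calls are now routed through RETURN_HR_IF_NOT_OK_MSG from OnnxruntimeErrors.h, so a failed OrtStatus* becomes a failed HRESULT instead of being dropped. That header is not shown in this diff; a minimal sketch of what such a macro could look like, assuming it extracts the error message for diagnostics and releases the status (the tracing and HRESULT mapping of the real macro may differ):

#include <winerror.h>  // E_FAIL

#define RETURN_HR_IF_NOT_OK_MSG(expr, ort_api_ptr)                      \
  do {                                                                  \
    OrtStatus* _status = (expr);                                        \
    if (_status != nullptr) {                                           \
      const OrtApi* _api = (ort_api_ptr);                               \
      const char* _msg = _api->GetErrorMessage(_status);                \
      (void)_msg; /* a real implementation would likely trace this */   \
      _api->ReleaseStatus(_status);                                     \
      return E_FAIL;                                                    \
    }                                                                   \
  } while (0)

Because the macro contains a return statement, it can only be used inside functions that return an HRESULT, which matches the shape of the session-builder methods above.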
122 changes: 57 additions & 65 deletions winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp
@@ -16,6 +16,8 @@

#include "OnnxruntimeEngine.h"

#include "OnnxruntimeErrors.h"

using namespace winrt::Windows::AI::MachineLearning;

// BitmapPixelFormat constants
@@ -44,12 +46,11 @@ static const char* c_supported_nominal_ranges[] =

namespace Windows::AI::MachineLearning {


// Forward declare CreateFeatureDescriptor
static winml::ILearningModelFeatureDescriptor
CreateFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata);

static TensorKind
@@ -100,7 +101,9 @@ TensorKindFromONNXTensorElementDataType(ONNXTensorElementDataType dataType) {
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: {
return TensorKind::Complex128;
}
default: { return TensorKind::Undefined; }
default: {
return TensorKind::Undefined;
}
}
}

@@ -153,7 +156,9 @@ TensorKindToString(TensorKind tensorKind) {
return "complex128";
}
case TensorKind::Undefined:
default: { return "undefined"; }
default: {
return "undefined";
}
}
}

@@ -310,9 +315,8 @@ GetTensorType(
const std::unordered_map<std::string, std::string>& metadata) {
const char* denotation;
size_t len;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len),
engine_factory->UseOrtApi());

constexpr char c_image[] = "IMAGE";
auto has_image_denotation = strncmp(denotation, c_image, _countof(c_image)) == 0;
@@ -327,14 +331,12 @@ GetTensorType(
// Check if the tensor value_info_proto is of type float.
// IMAGE tensors MUST be of type float
const OrtTensorTypeAndShapeInfo* tensor_info;
if (auto status = engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info),
engine_factory->UseOrtApi());

ONNXTensorElementDataType tensor_element_data_type;
if (auto status = engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type),
engine_factory->UseOrtApi());

auto tensor_kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type);
auto is_float_tensor = tensor_kind == TensorKind::Float;
@@ -396,36 +398,32 @@ GetTensorType(
static winml::ILearningModelFeatureDescriptor
CreateTensorFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata,
bool has_unsupported_image_metadata) {
auto type_info = feature_descriptor->type_info_.get();

const OrtTensorTypeAndShapeInfo* tensor_info;
if (auto status = engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info)) {
throw; //TODO fix throw here!;
}

THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info),
engine_factory->UseOrtApi());
size_t num_dims;
if (auto status = engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims),
engine_factory->UseOrtApi());

auto shape = std::vector<int64_t>(num_dims);
if (auto status = engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size())) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()),
engine_factory->UseOrtApi());

ONNXTensorElementDataType tensor_element_data_type;
if (auto status = engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type),
engine_factory->UseOrtApi());

auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type);

auto descriptor = winrt::make<winmlp::TensorFeatureDescriptor>(
WinML::Strings::HStringFromUTF8(feature_descriptor->name_),
WinML::Strings::HStringFromUTF8(feature_descriptor->description_), // description
feature_descriptor->name_length_ > 0, // is_required
feature_descriptor->name_length_ > 0, // is_required
kind,
shape,
has_unsupported_image_metadata);
@@ -436,32 +434,27 @@ CreateTensorFeatureDescriptor(
static winml::ILearningModelFeatureDescriptor
CreateImageFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata) {
auto type_info = feature_descriptor->type_info_.get();

const OrtTensorTypeAndShapeInfo* tensor_info;
if (auto status = engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info),
engine_factory->UseOrtApi());

size_t num_dims;
if (auto status = engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims),
engine_factory->UseOrtApi());

auto shape = std::vector<int64_t>(num_dims);
if (auto status = engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size())) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()),
engine_factory->UseOrtApi());

ONNXTensorElementDataType tensor_element_data_type;
if (auto status = engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type),
engine_factory->UseOrtApi());
auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type);


// pixel format and alpha
auto pixel_format_value = FetchMetadataValueOrNull(metadata, c_bitmap_pixel_format_key);
auto format_info = CreateBitmapPixelFormatAndAlphaModeInfo(pixel_format_value);
@@ -473,12 +466,12 @@ CreateImageFeatureDescriptor(
// to TensorFeatureDescriptor (invalid image metadata)
#ifdef DONE_LAYERING
// color space gamma value
auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key);
auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value);
auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key);
auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value);

// nominal range
auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key);
auto nominal_range = CreateImageNominalPixelRange(nominal_range_value);
auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key);
auto nominal_range = CreateImageNominalPixelRange(nominal_range_value);
#endif

// The current code assumes that the shape will be in NCHW.
@@ -504,25 +497,24 @@ CreateImageFeatureDescriptor(
static winml::ILearningModelFeatureDescriptor
CreateMapFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata) {
auto type_info = feature_descriptor->type_info_.get();

const OrtMapTypeInfo* map_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info),
engine_factory->UseOrtApi());

ONNXTensorElementDataType map_key_data_type;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type),
engine_factory->UseOrtApi());

auto key_kind = WinML::TensorKindFromONNXTensorElementDataType(map_key_data_type);

OrtTypeInfo* map_value_type_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info),
engine_factory->UseOrtApi());

UniqueOrtTypeInfo unique_map_value_type_info(map_value_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo);

OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper;
@@ -535,7 +527,6 @@ CreateMapFeatureDescriptor(
auto value_descriptor =
CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata);


auto descriptor = winrt::make<winmlp::MapFeatureDescriptor>(
WinML::Strings::HStringFromUTF8(feature_descriptor->name_),
WinML::Strings::HStringFromUTF8(feature_descriptor->description_),
@@ -548,19 +539,18 @@ CreateMapFeatureDescriptor(
static winml::ILearningModelFeatureDescriptor
CreateSequenceFeatureDescriptor(
OnnxruntimeEngineFactory* engine_factory,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const OnnxruntimeValueInfoWrapper* feature_descriptor,
const std::unordered_map<std::string, std::string>& metadata) {
auto type_info = feature_descriptor->type_info_.get();

const OrtSequenceTypeInfo* sequence_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info),
engine_factory->UseOrtApi());

OrtTypeInfo* sequence_element_type_info;
if (auto status = engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info)) {
throw; //TODO fix throw here!;
}
THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info),
engine_factory->UseOrtApi());

UniqueOrtTypeInfo unique_sequence_element_type_info(sequence_element_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo);

OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper;
@@ -590,7 +580,9 @@ CreateFeatureDescriptor(
auto type_info = feature_descriptor->type_info_.get();

ONNXType onnx_type;
engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type);
THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type),
engine_factory->UseOrtApi());

switch (onnx_type) {
case ONNXType::ONNX_TYPE_TENSOR: {
auto tensor_type = GetTensorType(engine_factory, type_info, metadata);
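The descriptor-converter changes replace the placeholder "throw; //TODO fix throw here!" blocks with THROW_IF_NOT_OK_MSG, which reports failures by throwing rather than by returning an HRESULT, since these helpers return WinML descriptors and have no status channel. As with the other macro, OnnxruntimeErrors.h is not part of this diff; below is a sketch under the assumption that the real macro throws an HRESULT-style exception (a plain std::runtime_error is used here only to keep the sketch self-contained):

#include <stdexcept>
#include <string>

#define THROW_IF_NOT_OK_MSG(expr, ort_api_ptr)            \
  do {                                                    \
    OrtStatus* _status = (expr);                          \
    if (_status != nullptr) {                             \
      const OrtApi* _api = (ort_api_ptr);                 \
      std::string _msg = _api->GetErrorMessage(_status);  \
      _api->ReleaseStatus(_status);                       \
      throw std::runtime_error(_msg);                     \
    }                                                     \
  } while (0)

Any exception thrown this way is expected to be caught again at the WinRT API surface, or by the API_IMPL wrappers in the adapter, before it can reach the application.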
