-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
User/xianz/winml adapter c api #2869
Changes from 4 commits
d8b4fac
ef9a412
044c861
e519d73
ddac824
0705e68
c39e8aa
fa65018
d0b89aa
f06434b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,8 @@ | |
|
||
#include "OnnxruntimeEngine.h" | ||
|
||
#include "OnnxruntimeErrors.h" | ||
|
||
using namespace winrt::Windows::AI::MachineLearning; | ||
|
||
// BitmapPixelFormat constants | ||
|
@@ -44,12 +46,11 @@ static const char* c_supported_nominal_ranges[] = | |
|
||
namespace Windows::AI::MachineLearning { | ||
|
||
|
||
// Forward declare CreateFeatureDescriptor | ||
static winml::ILearningModelFeatureDescriptor | ||
CreateFeatureDescriptor( | ||
OnnxruntimeEngineFactory* engine_factory, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const std::unordered_map<std::string, std::string>& metadata); | ||
|
||
static TensorKind | ||
|
@@ -100,7 +101,9 @@ TensorKindFromONNXTensorElementDataType(ONNXTensorElementDataType dataType) { | |
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: { | ||
return TensorKind::Complex128; | ||
} | ||
default: { return TensorKind::Undefined; } | ||
default: { | ||
return TensorKind::Undefined; | ||
} | ||
} | ||
} | ||
|
||
|
@@ -153,7 +156,9 @@ TensorKindToString(TensorKind tensorKind) { | |
return "complex128"; | ||
} | ||
case TensorKind::Undefined: | ||
default: { return "undefined"; } | ||
default: { | ||
return "undefined"; | ||
} | ||
} | ||
} | ||
|
||
|
@@ -310,9 +315,8 @@ GetTensorType( | |
const std::unordered_map<std::string, std::string>& metadata) { | ||
const char* denotation; | ||
size_t len; | ||
if (auto status = engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len)) { | ||
throw; //TODO fix throw here!; | ||
} | ||
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len), | ||
engine_factory->UseOrtApi()); | ||
|
||
auto has_image_denotation = strncmp(denotation, "IMAGE", len) != 0; | ||
if (!has_image_denotation) { | ||
|
@@ -395,7 +399,7 @@ GetTensorType( | |
static winml::ILearningModelFeatureDescriptor | ||
CreateTensorFeatureDescriptor( | ||
OnnxruntimeEngineFactory* engine_factory, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const std::unordered_map<std::string, std::string>& metadata, | ||
bool has_unsupported_image_metadata) { | ||
auto type_info = feature_descriptor->type_info_.get(); | ||
|
@@ -412,7 +416,7 @@ CreateTensorFeatureDescriptor( | |
|
||
auto shape = std::vector<int64_t>(num_dims); | ||
if (auto status = engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size())) { | ||
throw; //TODO fix throw here!; | ||
throw; //TODO fix throw here!; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
|
||
ONNXTensorElementDataType tensor_element_data_type; | ||
|
@@ -424,7 +428,7 @@ CreateTensorFeatureDescriptor( | |
auto descriptor = winrt::make<winmlp::TensorFeatureDescriptor>( | ||
WinML::Strings::HStringFromUTF8(feature_descriptor->name_), | ||
WinML::Strings::HStringFromUTF8(feature_descriptor->description_), // description | ||
feature_descriptor->name_length_ > 0, // is_required | ||
feature_descriptor->name_length_ > 0, // is_required | ||
kind, | ||
shape, | ||
has_unsupported_image_metadata); | ||
|
@@ -435,7 +439,7 @@ CreateTensorFeatureDescriptor( | |
static winml::ILearningModelFeatureDescriptor | ||
CreateImageFeatureDescriptor( | ||
OnnxruntimeEngineFactory* engine_factory, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const std::unordered_map<std::string, std::string>& metadata) { | ||
auto type_info = feature_descriptor->type_info_.get(); | ||
|
||
|
@@ -460,7 +464,6 @@ CreateImageFeatureDescriptor( | |
} | ||
auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); | ||
|
||
|
||
// pixel format and alpha | ||
auto pixel_format_value = FetchMetadataValueOrNull(metadata, c_bitmap_pixel_format_key); | ||
auto format_info = CreateBitmapPixelFormatAndAlphaModeInfo(pixel_format_value); | ||
|
@@ -472,12 +475,12 @@ CreateImageFeatureDescriptor( | |
// to TensorFeatureDescriptor (invalid image metadata) | ||
#ifdef DONE_LAYERING | ||
// color space gamma value | ||
auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key); | ||
auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value); | ||
auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key); | ||
auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value); | ||
|
||
// nominal range | ||
auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key); | ||
auto nominal_range = CreateImageNominalPixelRange(nominal_range_value); | ||
auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key); | ||
auto nominal_range = CreateImageNominalPixelRange(nominal_range_value); | ||
#endif | ||
|
||
// The current code assumes that the shape will be in NCHW. | ||
|
@@ -503,25 +506,24 @@ CreateImageFeatureDescriptor( | |
static winml::ILearningModelFeatureDescriptor | ||
CreateMapFeatureDescriptor( | ||
OnnxruntimeEngineFactory* engine_factory, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const std::unordered_map<std::string, std::string>& metadata) { | ||
auto type_info = feature_descriptor->type_info_.get(); | ||
|
||
const OrtMapTypeInfo* map_info; | ||
if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info)) { | ||
throw; //TODO fix throw here!; | ||
} | ||
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info), | ||
engine_factory->UseOrtApi()); | ||
|
||
ONNXTensorElementDataType map_key_data_type; | ||
if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type)) { | ||
throw; //TODO fix throw here!; | ||
} | ||
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type), | ||
engine_factory->UseOrtApi()); | ||
|
||
auto key_kind = WinML::TensorKindFromONNXTensorElementDataType(map_key_data_type); | ||
|
||
OrtTypeInfo* map_value_type_info; | ||
if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info)) { | ||
throw; //TODO fix throw here!; | ||
} | ||
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info), | ||
engine_factory->UseOrtApi()); | ||
|
||
UniqueOrtTypeInfo unique_map_value_type_info(map_value_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo); | ||
|
||
OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper; | ||
|
@@ -534,7 +536,6 @@ CreateMapFeatureDescriptor( | |
auto value_descriptor = | ||
CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata); | ||
|
||
|
||
auto descriptor = winrt::make<winmlp::MapFeatureDescriptor>( | ||
WinML::Strings::HStringFromUTF8(feature_descriptor->name_), | ||
WinML::Strings::HStringFromUTF8(feature_descriptor->description_), | ||
|
@@ -547,19 +548,18 @@ CreateMapFeatureDescriptor( | |
static winml::ILearningModelFeatureDescriptor | ||
CreateSequenceFeatureDescriptor( | ||
OnnxruntimeEngineFactory* engine_factory, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const OnnxruntimeValueInfoWrapper* feature_descriptor, | ||
const std::unordered_map<std::string, std::string>& metadata) { | ||
auto type_info = feature_descriptor->type_info_.get(); | ||
|
||
const OrtSequenceTypeInfo* sequence_info; | ||
if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info)) { | ||
throw; //TODO fix throw here!; | ||
} | ||
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info), | ||
engine_factory->UseOrtApi()); | ||
|
||
OrtTypeInfo* sequence_element_type_info; | ||
if (auto status = engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info)) { | ||
throw; //TODO fix throw here!; | ||
} | ||
THROW_IF_WINMLA_API_FAIL_MSG(engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info), | ||
engine_factory->UseOrtApi()); | ||
|
||
UniqueOrtTypeInfo unique_sequence_element_type_info(sequence_element_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo); | ||
|
||
OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper; | ||
|
@@ -590,6 +590,8 @@ CreateFeatureDescriptor( | |
|
||
ONNXType onnx_type; | ||
engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type); | ||
engine_factory->UseOrtApi(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. miss-added, will be deleted #Resolved |
||
|
||
switch (onnx_type) { | ||
case ONNXType::ONNX_TYPE_TENSOR: { | ||
auto tensor_type = GetTensorType(engine_factory, type_info, metadata); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did clang format for these files. #Resolved