Skip to content

Commit

Permalink
wrapper all existing winml adapter apis with API_IMPL to try catch (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxiang1993 authored Jan 17, 2020
1 parent fd69fd8 commit 18412d3
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 1 deletion.
7 changes: 7 additions & 0 deletions onnxruntime/core/framework/onnxruntime_typeinfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/sparse_tensor.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/ort_apis.h"
#include "core/framework/error_code_helper.h"

#include "core/framework/tensor_type_and_shape.h"
#include "../../winml/adapter/winml_adapter_map_type_info.h"
Expand Down Expand Up @@ -61,19 +62,25 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtType
}

ORT_API_STATUS_IMPL(winmla::CastTypeInfoToMapTypeInfo, const OrtTypeInfo* type_info, const OrtMapTypeInfo** out) {
API_IMPL_BEGIN
*out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info_ : nullptr;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::CastTypeInfoToSequenceTypeInfo, const OrtTypeInfo* type_info, const OrtSequenceTypeInfo** out) {
API_IMPL_BEGIN
*out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info_ : nullptr;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::GetDenotationFromTypeInfo, const OrtTypeInfo* type_info, const char** const out, size_t* len) {
API_IMPL_BEGIN
*out = type_info->denotation_.c_str();
*len = type_info->denotation_.size();
return nullptr;
API_IMPL_END
}

ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) {
Expand Down
2 changes: 2 additions & 0 deletions winml/adapter/winml_adapter_dml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void DmlConfigureProviderFactoryDefaultRoundingMode(onnxruntime::IExecutionProvi

ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
ID3D12Device* d3d_device, ID3D12CommandQueue* queue) {
API_IMPL_BEGIN
auto dml_device = CreateDmlDevice(d3d_device);
if (auto status = OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue)) {
return status;
Expand All @@ -61,6 +62,7 @@ ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_
// So we create the provider with rounding disabled, and expect the caller to enable it after.
onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled);
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled) {
Expand Down
6 changes: 5 additions & 1 deletion winml/adapter/winml_adapter_environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class WinmlAdapterLoggingWrapper : public LoggingWrapper {
ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function,
_In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
_In_ const char* logid, _Outptr_ OrtEnv** out) {
API_IMPL_BEGIN
std::string name = logid;
std::unique_ptr<onnxruntime::logging::ISink> logger = onnxruntime::make_unique<WinmlAdapterLoggingWrapper>(logging_function, profiling_function, logger_param);

Expand All @@ -64,6 +65,7 @@ ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* en
// Set a new default logging manager
env->SetLoggingManager(std::move(winml_logging_manager));
return nullptr;
API_IMPL_END
}

// Override select shape inference functions which are incomplete in ONNX with versions that are complete,
Expand All @@ -72,11 +74,13 @@ ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* en
// registered schema are reachable only after upstream schema have been revised in a later OS release,
// which would be a compatibility risk.
ORT_API_STATUS_IMPL(winmla::OverrideSchema) {
API_IMPL_BEGIN
#ifdef USE_DML
static std::once_flag schema_override_once_flag;
std::call_once(schema_override_once_flag, []() {
SchemaInferenceOverrider::OverrideSchemaInferenceFunctions();
});
#endif USE_DML.
return nullptr;
#endif USE_DML
API_IMPL_END
}
4 changes: 4 additions & 0 deletions winml/adapter/winml_adapter_map_type_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,16 @@ OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) {

// OrtMapTypeInfo Accessors
ORT_API_STATUS_IMPL(winmla::GetMapKeyType, const OrtMapTypeInfo* map_type_info, enum ONNXTensorElementDataType* out) {
API_IMPL_BEGIN
*out = map_type_info->map_key_type_;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::GetMapValueType, const OrtMapTypeInfo* map_type_info, OrtTypeInfo** out) {
API_IMPL_BEGIN
return map_type_info->map_value_type_->Clone(out);
API_IMPL_END
}

ORT_API(void, winmla::ReleaseMapTypeInfo, OrtMapTypeInfo* ptr) {
Expand Down
36 changes: 36 additions & 0 deletions winml/adapter/winml_adapter_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,90 +195,118 @@ std::unique_ptr<onnx::ModelProto> OrtModel::DetachModelProto() {
}

ORT_API_STATUS_IMPL(winmla::CreateModelFromPath, const char* model_path, size_t size, OrtModel** out) {
API_IMPL_BEGIN
if (auto status = OrtModel::CreateOrtModelFromPath(model_path, size, out)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::CreateModelFromData, void* data, size_t size, OrtModel** out) {
API_IMPL_BEGIN
if (auto status = OrtModel::CreateOrtModelFromData(data, size, out)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::CloneModel, const OrtModel* in, OrtModel** out) {
API_IMPL_BEGIN
auto model_proto_copy = std::make_unique<onnx::ModelProto>(*in->UseModelProto());
if (auto status = OrtModel::CreateOrtModelFromProto(std::move(model_proto_copy), out)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetAuthor, const OrtModel* model, const char** const author, size_t* len) {
API_IMPL_BEGIN
*author = model->UseModelInfo()->author_.c_str();
*len = model->UseModelInfo()->author_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetName, const OrtModel* model, const char** const name, size_t* len) {
API_IMPL_BEGIN
*name = model->UseModelInfo()->name_.c_str();
*len = model->UseModelInfo()->name_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetDomain, const OrtModel* model, const char** const domain, size_t* len) {
API_IMPL_BEGIN
*domain = model->UseModelInfo()->domain_.c_str();
*len = model->UseModelInfo()->domain_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetDescription, const OrtModel* model, const char** const description, size_t* len) {
API_IMPL_BEGIN
*description = model->UseModelInfo()->description_.c_str();
*len = model->UseModelInfo()->description_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetVersion, const OrtModel* model, int64_t* version) {
API_IMPL_BEGIN
*version = model->UseModelInfo()->version_;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetMetadataCount, const OrtModel* model, size_t* count) {
API_IMPL_BEGIN
*count = model->UseModelInfo()->model_metadata_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetMetadata, const OrtModel* model, size_t count, const char** const key,
size_t* key_len, const char** const value, size_t* value_len) {
API_IMPL_BEGIN
*key = model->UseModelInfo()->model_metadata_[count].first.c_str();
*key_len = model->UseModelInfo()->model_metadata_[count].first.size();
*value = model->UseModelInfo()->model_metadata_[count].second.c_str();
*value_len = model->UseModelInfo()->model_metadata_[count].second.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputCount, const OrtModel* model, size_t* count) {
API_IMPL_BEGIN
*count = model->UseModelInfo()->input_features_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputCount, const OrtModel* model, size_t* count) {
API_IMPL_BEGIN
*count = model->UseModelInfo()->output_features_.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputName, const OrtModel* model, size_t index, const char** input_name, size_t* count) {
API_IMPL_BEGIN
*input_name = model->UseModelInfo()->input_features_[index]->name().c_str();
*count = model->UseModelInfo()->input_features_[index]->name().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t index, const char** output_name, size_t* count) {
API_IMPL_BEGIN
*output_name = model->UseModelInfo()->output_features_[index]->name().c_str();
*count = model->UseModelInfo()->output_features_[index]->name().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, const char** input_description, size_t* count) {
Expand All @@ -288,26 +316,33 @@ ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, siz
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, const char** output_description, size_t* count) {
API_IMPL_BEGIN
*output_description = model->UseModelInfo()->output_features_[index]->doc_string().c_str();
*count = model->UseModelInfo()->output_features_[index]->doc_string().size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetInputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) {
API_IMPL_BEGIN
if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->input_features_[index]->type(), type_info)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelGetOutputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) {
API_IMPL_BEGIN
if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->output_features_[index]->type(), type_info)) {
return status;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) {
API_IMPL_BEGIN
auto model_info = model->UseModelInfo();
auto model_proto = model->UseModelProto();
auto& graph = model_proto->graph();
Expand Down Expand Up @@ -372,6 +407,7 @@ ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) {
}
}
return nullptr;
API_IMPL_END
}

ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) {
Expand Down
2 changes: 2 additions & 0 deletions winml/adapter/winml_adapter_sequence_type_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) {
}

ORT_API_STATUS_IMPL(winmla::GetSequenceElementType, const OrtSequenceTypeInfo* sequence_type_info, OrtTypeInfo** out) {
API_IMPL_BEGIN
return sequence_type_info->sequence_key_type_->Clone(out);
API_IMPL_END
}

ORT_API(void, winmla::ReleaseSequenceTypeInfo, OrtSequenceTypeInfo* ptr) {
Expand Down

0 comments on commit 18412d3

Please sign in to comment.