diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp
index d68443d689088..bc667f90636af 100644
--- a/winml/adapter/winml_adapter_model.cpp
+++ b/winml/adapter/winml_adapter_model.cpp
@@ -310,9 +310,11 @@ ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t in
 }
 
 ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, const char** input_description, size_t* count) {
+  API_IMPL_BEGIN
   *input_description = model->UseModelInfo()->input_features_[index]->doc_string().c_str();
   *count = model->UseModelInfo()->input_features_[index]->doc_string().size();
   return nullptr;
+  API_IMPL_END
 }
 
 ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, const char** output_description, size_t* count) {
diff --git a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp
index ab03d9ed32a5a..9aeedb7613099 100644
--- a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp
+++ b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp
@@ -5,6 +5,7 @@
 
 #include "OnnxruntimeCpuSessionBuilder.h"
 #include "OnnxruntimeEngine.h"
+#include "OnnxruntimeErrors.h"
 
 using namespace Windows::AI::MachineLearning;
 
@@ -22,25 +23,29 @@ OnnxruntimeCpuSessionBuilder::CreateSessionOptions(
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtSessionOptions* ort_options;
-  ort_api->CreateSessionOptions(&ort_options);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&ort_options),
+                          ort_api);
 
   auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions);
 
   // set the graph optimization level to all (used to be called level 3)
-  ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL),
+                          ort_api);
 
   // Onnxruntime will use half the number of concurrent threads supported on the system
   // by default. This causes MLAS to not exercise every logical core.
   // We force the thread pool size to be maxxed out to ensure that WinML always
   // runs the fastest.
-  ort_api->SetIntraOpNumThreads(session_options.get(), std::thread::hardware_concurrency());
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->SetIntraOpNumThreads(session_options.get(), std::thread::hardware_concurrency()),
+                          ort_api);
 
 #ifndef _WIN64
   auto use_arena = false;
 #else
   auto use_arena = true;
 #endif
-  winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena),
+                          ort_api);
 
   // call release() so the underlying OrtSessionOptions object isn't freed
   *options = session_options.release();
@@ -48,7 +53,6 @@ OnnxruntimeCpuSessionBuilder::CreateSessionOptions(
   return S_OK;
 }
 
-
 HRESULT
 OnnxruntimeCpuSessionBuilder::CreateSession(
     OrtSessionOptions* options,
@@ -62,11 +66,13 @@ OnnxruntimeCpuSessionBuilder::CreateSession(
   RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env));
 
   OrtSession* ort_session_raw;
-  winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw),
+                          engine_factory_->UseOrtApi());
+
   auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession);
-  
+
   *session = ort_session.release();
-  
+
   return S_OK;
 }
 
@@ -76,9 +82,8 @@ OnnxruntimeCpuSessionBuilder::Initialize(
   RETURN_HR_IF_NULL(E_INVALIDARG, session);
 
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
-  if (auto status = winml_adapter_api->SessionInitialize(session)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session),
+                          engine_factory_->UseOrtApi());
 
   return S_OK;
 }
diff --git a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp
index cf177eb0e3651..810a8265d6cbf 100644
--- a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp
+++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp
@@ -16,6 +16,8 @@
 
 #include "OnnxruntimeEngine.h"
 
+#include "OnnxruntimeErrors.h"
+
 using namespace winrt::Windows::AI::MachineLearning;
 
 // BitmapPixelFormat constants
@@ -44,12 +46,11 @@ static const char* c_supported_nominal_ranges[] =
 
 namespace Windows::AI::MachineLearning {
-
 // Forward declare CreateFeatureDescriptor
 static winml::ILearningModelFeatureDescriptor
 CreateFeatureDescriptor(
     OnnxruntimeEngineFactory* engine_factory,
-    const OnnxruntimeValueInfoWrapper* feature_descriptor, 
+    const OnnxruntimeValueInfoWrapper* feature_descriptor,
     const std::unordered_map& metadata);
 
 static TensorKind
@@ -100,7 +101,9 @@ TensorKindFromONNXTensorElementDataType(ONNXTensorElementDataType dataType) {
     case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: {
       return TensorKind::Complex128;
    }
-    default: { return TensorKind::Undefined; }
+    default: {
+      return TensorKind::Undefined;
+    }
   }
 }
 
@@ -153,7 +156,9 @@ TensorKindToString(TensorKind tensorKind) {
       return "complex128";
     }
     case TensorKind::Undefined:
-    default: { return "undefined"; }
+    default: {
+      return "undefined";
+    }
   }
 }
 
@@ -310,9 +315,8 @@ GetTensorType(
     const std::unordered_map& metadata) {
   const char* denotation;
   size_t len;
-  if (auto status = engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len),
+                      engine_factory->UseOrtApi());
 
   constexpr char c_image[] = "IMAGE";
   auto has_image_denotation = strncmp(denotation, c_image, _countof(c_image)) == 0;
@@ -327,14 +331,12 @@ GetTensorType(
   // Check if the tensor value_info_proto is of type float.
   // IMAGE tensors MUST be of type float
   const OrtTensorTypeAndShapeInfo* tensor_info;
-  if (auto status = engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info),
+                      engine_factory->UseOrtApi());
 
   ONNXTensorElementDataType tensor_element_data_type;
-  if (auto status = engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type),
+                      engine_factory->UseOrtApi());
 
   auto tensor_kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type);
   auto is_float_tensor = tensor_kind == TensorKind::Float;
@@ -396,36 +398,32 @@ GetTensorType(
 static winml::ILearningModelFeatureDescriptor
 CreateTensorFeatureDescriptor(
     OnnxruntimeEngineFactory* engine_factory,
-    const OnnxruntimeValueInfoWrapper* feature_descriptor, 
+    const OnnxruntimeValueInfoWrapper* feature_descriptor,
     const std::unordered_map& metadata,
     bool has_unsupported_image_metadata) {
   auto type_info = feature_descriptor->type_info_.get();
 
   const OrtTensorTypeAndShapeInfo* tensor_info;
-  if (auto status = engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info)) {
-    throw;  //TODO fix throw here!;
-  }
-
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info),
+                      engine_factory->UseOrtApi());
   size_t num_dims;
-  if (auto status = engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims),
+                      engine_factory->UseOrtApi());
 
   auto shape = std::vector(num_dims);
-  if (auto status = engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size())) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()),
+                      engine_factory->UseOrtApi());
 
   ONNXTensorElementDataType tensor_element_data_type;
-  if (auto status = engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type),
+                      engine_factory->UseOrtApi());
+
   auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type);
 
   auto descriptor = winrt::make(
       WinML::Strings::HStringFromUTF8(feature_descriptor->name_),
       WinML::Strings::HStringFromUTF8(feature_descriptor->description_),  // description
-      feature_descriptor->name_length_ > 0, // is_required
+      feature_descriptor->name_length_ > 0,  // is_required
       kind,
       shape,
       has_unsupported_image_metadata);
@@ -436,32 +434,27 @@ CreateTensorFeatureDescriptor(
 static winml::ILearningModelFeatureDescriptor
 CreateImageFeatureDescriptor(
     OnnxruntimeEngineFactory* engine_factory,
-    const OnnxruntimeValueInfoWrapper* feature_descriptor, 
+    const OnnxruntimeValueInfoWrapper* feature_descriptor,
     const std::unordered_map& metadata) {
   auto type_info = feature_descriptor->type_info_.get();
 
   const OrtTensorTypeAndShapeInfo* tensor_info;
-  if (auto status = engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info),
+                      engine_factory->UseOrtApi());
 
   size_t num_dims;
-  if (auto status = engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims),
+                      engine_factory->UseOrtApi());
 
   auto shape = std::vector(num_dims);
-  if (auto status = engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size())) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()),
+                      engine_factory->UseOrtApi());
 
   ONNXTensorElementDataType tensor_element_data_type;
-  if (auto status = engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type),
+                      engine_factory->UseOrtApi());
 
   auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type);
-
   // pixel format and alpha
   auto pixel_format_value = FetchMetadataValueOrNull(metadata, c_bitmap_pixel_format_key);
   auto format_info = CreateBitmapPixelFormatAndAlphaModeInfo(pixel_format_value);
@@ -473,12 +466,12 @@ CreateImageFeatureDescriptor(
   // to TensorFeatureDescriptor (invalid image metadata)
 #ifdef DONE_LAYERING
   // color space gamma value
-  auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key);
-  auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value);
+  auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key);
+  auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value);
 
   // nominal range
-  auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key);
-  auto nominal_range = CreateImageNominalPixelRange(nominal_range_value);
+  auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key);
+  auto nominal_range = CreateImageNominalPixelRange(nominal_range_value);
 #endif
 
   // The current code assumes that the shape will be in NCHW.
@@ -504,25 +497,24 @@ CreateImageFeatureDescriptor(
 static winml::ILearningModelFeatureDescriptor
 CreateMapFeatureDescriptor(
     OnnxruntimeEngineFactory* engine_factory,
-    const OnnxruntimeValueInfoWrapper* feature_descriptor, 
+    const OnnxruntimeValueInfoWrapper* feature_descriptor,
     const std::unordered_map& metadata) {
   auto type_info = feature_descriptor->type_info_.get();
 
   const OrtMapTypeInfo* map_info;
-  if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info),
+                      engine_factory->UseOrtApi());
 
   ONNXTensorElementDataType map_key_data_type;
-  if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type),
+                      engine_factory->UseOrtApi());
+
   auto key_kind = WinML::TensorKindFromONNXTensorElementDataType(map_key_data_type);
 
   OrtTypeInfo* map_value_type_info;
-  if (auto status = engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info),
+                      engine_factory->UseOrtApi());
+
   UniqueOrtTypeInfo unique_map_value_type_info(map_value_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo);
 
   OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper;
@@ -535,7 +527,6 @@ CreateMapFeatureDescriptor(
 
   auto value_descriptor =
       CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata);
-
   auto descriptor = winrt::make(
       WinML::Strings::HStringFromUTF8(feature_descriptor->name_),
       WinML::Strings::HStringFromUTF8(feature_descriptor->description_),
@@ -548,19 +539,18 @@ CreateMapFeatureDescriptor(
 static winml::ILearningModelFeatureDescriptor
 CreateSequenceFeatureDescriptor(
     OnnxruntimeEngineFactory* engine_factory,
-    const OnnxruntimeValueInfoWrapper* feature_descriptor, 
+    const OnnxruntimeValueInfoWrapper* feature_descriptor,
     const std::unordered_map& metadata) {
   auto type_info = feature_descriptor->type_info_.get();
 
   const OrtSequenceTypeInfo* sequence_info;
-  if (auto status = engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info),
+                      engine_factory->UseOrtApi());
 
   OrtTypeInfo* sequence_element_type_info;
-  if (auto status = engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info)) {
-    throw;  //TODO fix throw here!;
-  }
+  THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info),
+                      engine_factory->UseOrtApi());
+
   UniqueOrtTypeInfo unique_sequence_element_type_info(sequence_element_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo);
 
   OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper;
@@ -590,7 +580,9 @@ CreateFeatureDescriptor(
   auto type_info = feature_descriptor->type_info_.get();
 
   ONNXType onnx_type;
-  engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type);
+  THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type),
+                      engine_factory->UseOrtApi());
+
   switch (onnx_type) {
     case ONNXType::ONNX_TYPE_TENSOR: {
       auto tensor_type = GetTensorType(engine_factory, type_info, metadata);
diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp
index 446d8657428d5..8f23e6864f73e 100644
--- a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp
+++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp
@@ -7,6 +7,7 @@
 
 #include "OnnxruntimeDmlSessionBuilder.h"
 #include "OnnxruntimeEngine.h"
+#include "OnnxruntimeErrors.h"
 #include "LearningModelDevice.h"
 
 using namespace Windows::AI::MachineLearning;
@@ -27,26 +28,31 @@ OnnxruntimeDmlSessionBuilder::CreateSessionOptions(
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtSessionOptions* ort_options;
-  ort_api->CreateSessionOptions(&ort_options);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&ort_options),
+                          ort_api);
 
   auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions);
 
   // set the graph optimization level to all (used to be called level 3)
-  ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL),
+                          ort_api);
 
   // Disable the mem pattern session option for DML. It will cause problems with how memory is allocated.
-  ort_api->DisableMemPattern(session_options.get());
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()),
+                          ort_api);
 
   // Request the dml ep
-  winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get());
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get()),
+                          ort_api);
 
 #ifndef _WIN64
   auto use_arena = false;
 #else
   auto use_arena = true;
 #endif
-  winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena);
-
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena),
+                          ort_api);
+
   // call release() so the underlying OrtSessionOptions object isn't freed
   *options = session_options.release();
@@ -65,7 +71,8 @@ HRESULT OnnxruntimeDmlSessionBuilder::CreateSession(
   RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env));
 
   OrtSession* ort_session_raw;
-  winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw),
+                          engine_factory_->UseOrtApi());
   auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession);
 
   *session = ort_session.release();
@@ -77,18 +84,20 @@ HRESULT OnnxruntimeDmlSessionBuilder::Initialize(
     OrtSession* session) {
   RETURN_HR_IF_NULL(E_INVALIDARG, session);
 
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
-
-  if (auto status = winml_adapter_api->SessionInitialize(session)) {
-    return E_FAIL;
-  }
+
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session),
+                          engine_factory_->UseOrtApi());
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
 
-  winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true),
+                          engine_factory_->UseOrtApi());
 
   // Flush the D3D12 work from the DML execution provider
-  winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider),
+                          engine_factory_->UseOrtApi());
 
   return S_OK;
 }
diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp
index 1124705ed3235..53a18485e466c 100644
--- a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp
+++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp
@@ -7,6 +7,7 @@
 #include "OnnxruntimeEngineBuilder.h"
 #include "OnnxruntimeModel.h"
 #include "OnnxruntimeSessionBuilder.h"
+#include "OnnxruntimeErrors.h"
 
 using namespace WinML;
 
@@ -68,7 +69,9 @@ ONNXTensorElementDataTypeFromTensorKind(winml::TensorKind kind) {
     case winml::TensorKind::Complex128: {
       return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
     }
-    default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }
+    default: {
+      return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+    }
   }
 }
 
@@ -97,14 +100,17 @@ HRESULT OnnxruntimeValue::IsCpu(bool* out) {
   auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();
 
   OrtMemoryInfo* ort_memory_info;
-  winml_adapter_api->GetValueMemoryInfo(value_.get(), &ort_memory_info);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(value_.get(), &ort_memory_info),
+                          ort_api);
   auto memory_info = UniqueOrtMemoryInfo(ort_memory_info, ort_api->ReleaseMemoryInfo);
 
   const char* name;
-  ort_api->MemoryInfoGetName(memory_info.get(), &name);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(memory_info.get(), &name),
+                          ort_api);
 
   OrtMemType type;
-  ort_api->MemoryInfoGetMemType(memory_info.get(), &type);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(memory_info.get(), &type),
+                          ort_api);
 
   *out = !strcmp(name, "Cpu") ||
          type == OrtMemType::OrtMemTypeCPUOutput ||
@@ -129,22 +135,26 @@ static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value,
   std::vector out;
 
   size_t size;
-  ort_api->GetDimensionsCount(type_and_shape_info, &size);
+  THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info, &size),
+                      ort_api);
 
   std::vector shape(size);
-  ort_api->GetDimensions(type_and_shape_info, &shape[0], size);
+  THROW_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info, &shape[0], size),
+                      ort_api);
   auto length = ShapeSize(shape.data(), shape.size());
 
   // make a big buffer to hold all the string data
   size_t buffer_length;
-  ort_api->GetStringTensorDataLength(ort_value, &buffer_length);
+  THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorDataLength(ort_value, &buffer_length),
+                      ort_api);
 
   std::vector strings;
   std::unique_ptr buffer(new uint8_t[buffer_length]);
   std::vector offsets(length);
 
-  ort_api->GetStringTensorContent(ort_value, buffer.get(), buffer_length, offsets.data(), offsets.size());
+  THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorContent(ort_value, buffer.get(), buffer_length, offsets.data(), offsets.size()),
+                      ort_api);
 
   // now go build all the strings
   for (auto i = 0; i < length; ++i) {
@@ -177,7 +187,6 @@ WinML::Resource OnnxruntimeValue::GetResource() {
     winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data,
                                                          reinterpret_cast(&resource));
     return WinML::Resource(resource, [](void*) { /*do nothing, as this pointer is actually a com pointer! */ });
-
   } else {
     int is_tensor;
     ort_api->IsTensor(value_.get(), &is_tensor);
@@ -206,7 +215,8 @@ HRESULT OnnxruntimeValue::IsTensor(bool* out) {
   auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
 
   ONNXType type = ONNXType::ONNX_TYPE_UNKNOWN;
-  ort_api->GetValueType(value_.get(), &type);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueType(value_.get(), &type),
+                          ort_api);
   *out = type == ONNXType::ONNX_TYPE_TENSOR;
   return S_OK;
 }
@@ -214,11 +224,13 @@ HRESULT OnnxruntimeValue::IsTensor(bool* out) {
 HRESULT OnnxruntimeValue::IsOfTensorType(winml::TensorKind kind, bool* out) {
   auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
   OrtTensorTypeAndShapeInfo* info = nullptr;
-  ort_api->GetTensorTypeAndShape(value_.get(), &info);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info),
+                          ort_api);
   auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo);
 
   ONNXTensorElementDataType data_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
-  ort_api->GetTensorElementType(type_and_shape_info.get(), &data_type);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(type_and_shape_info.get(), &data_type),
+                          ort_api);
 
   *out = data_type == ONNXTensorElementDataTypeFromTensorKind(kind);
   return S_OK;
@@ -227,14 +239,17 @@ HRESULT OnnxruntimeValue::IsOfTensorType(winml::TensorKind kind, bool* out) {
 HRESULT OnnxruntimeValue::GetTensorShape(std::vector& shape_vector) {
   auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
   OrtTensorTypeAndShapeInfo* info = nullptr;
-  ort_api->GetTensorTypeAndShape(value_.get(), &info);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info),
+                          ort_api);
   auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo);
 
   size_t size;
-  ort_api->GetDimensionsCount(type_and_shape_info.get(), &size);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info.get(), &size),
+                          ort_api);
 
   std::vector shape(size);
-  ort_api->GetDimensions(type_and_shape_info.get(), &shape[0], size);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info.get(), &shape[0], size),
+                          ort_api);
 
   shape_vector = std::move(shape);
   return S_OK;
@@ -245,26 +260,32 @@ static bool EnsureMapTypeInfo(OnnxruntimeEngine* engine, OrtTypeInfo* type_info,
   auto winml_adapter_api = engine->GetEngineFactory()->UseWinmlAdapterApi();
 
   const OrtMapTypeInfo* map_info;
-  winml_adapter_api->CastTypeInfoToMapTypeInfo(type_info, &map_info);
+  THROW_IF_NOT_OK_MSG(winml_adapter_api->CastTypeInfoToMapTypeInfo(type_info, &map_info),
+                      ort_api);
 
   ONNXTensorElementDataType map_key_type;
-  winml_adapter_api->GetMapKeyType(map_info, &map_key_type);
+  THROW_IF_NOT_OK_MSG(winml_adapter_api->GetMapKeyType(map_info, &map_key_type),
+                      ort_api);
 
   if (map_key_type == ONNXTensorElementDataTypeFromTensorKind(key_kind)) {
     OrtTypeInfo* value_info;
-    winml_adapter_api->GetMapValueType(map_info, &value_info);
+    THROW_IF_NOT_OK_MSG(winml_adapter_api->GetMapValueType(map_info, &value_info),
+                        ort_api);
     auto map_value_info = UniqueOrtTypeInfo(value_info, ort_api->ReleaseTypeInfo);
 
     const OrtTensorTypeAndShapeInfo* value_tensor_info = nullptr;
-    ort_api->CastTypeInfoToTensorInfo(map_value_info.get(), &value_tensor_info);
+    THROW_IF_NOT_OK_MSG(ort_api->CastTypeInfoToTensorInfo(map_value_info.get(), &value_tensor_info),
+                        ort_api);
 
     if (value_tensor_info) {
       ONNXTensorElementDataType map_value_tensor_type;
-      ort_api->GetTensorElementType(value_tensor_info, &map_value_tensor_type);
+      THROW_IF_NOT_OK_MSG(ort_api->GetTensorElementType(value_tensor_info, &map_value_tensor_type),
+                          ort_api);
 
       if (map_value_tensor_type == ONNXTensorElementDataTypeFromTensorKind(value_kind)) {
         size_t num_dims;
-        ort_api->GetDimensionsCount(value_tensor_info, &num_dims);
+        THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(value_tensor_info, &num_dims),
+                            ort_api);
 
         return num_dims == 0;
       }
@@ -277,11 +298,13 @@ HRESULT OnnxruntimeValue::IsOfMapType(winml::TensorKind key_kind, winml::TensorK
   auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
 
   OrtTypeInfo* info = nullptr;
-  ort_api->GetTypeInfo(value_.get(), &info);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info),
+                          ort_api);
   auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo);
 
   ONNXType type;
-  ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type),
+                          ort_api);
 
   if (type == ONNXType::ONNX_TYPE_MAP) {
     *out = EnsureMapTypeInfo(engine_.Get(), unique_type_info.get(), key_kind, value_kind);
@@ -297,18 +320,22 @@ HRESULT OnnxruntimeValue::IsOfVectorMapType(winml::TensorKind key_kind, winml::T
   auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();
 
   OrtTypeInfo* info = nullptr;
-  ort_api->GetTypeInfo(value_.get(), &info);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info),
+                          ort_api);
   auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo);
 
   ONNXType type;
-  ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type),
+                          ort_api);
 
   if (type == ONNXType::ONNX_TYPE_SEQUENCE) {
     const OrtSequenceTypeInfo* sequence_info;
-    winml_adapter_api->CastTypeInfoToSequenceTypeInfo(unique_type_info.get(), &sequence_info);
+    RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CastTypeInfoToSequenceTypeInfo(unique_type_info.get(), &sequence_info),
+                            ort_api);
 
     OrtTypeInfo* element_info;
-    winml_adapter_api->GetSequenceElementType(sequence_info, &element_info);
+    RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetSequenceElementType(sequence_info, &element_info),
+                            ort_api);
     auto unique_element_info = UniqueOrtTypeInfo(element_info, ort_api->ReleaseTypeInfo);
 
     *out = EnsureMapTypeInfo(engine_.Get(), unique_element_info.get(), key_kind, value_kind);
@@ -351,8 +378,8 @@ HRESULT OnnxruntimeEngine::LoadModel(_In_ IModel* model) {
 
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
-  winml_adapter_api->SessionLoadAndPurloinModel(session_.get(), ort_model);
-
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionLoadAndPurloinModel(session_.get(), ort_model),
+                          engine_factory_->UseOrtApi());
   return S_OK;
 }
 
@@ -363,19 +390,22 @@ HRESULT OnnxruntimeEngine::Initialize() {
 
 HRESULT OnnxruntimeEngine::RegisterGraphTransformers() {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
-  winml_adapter_api->SessionRegisterGraphTransformers(session_.get());
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterGraphTransformers(session_.get()),
+                          engine_factory_->UseOrtApi());
   return S_OK;
 }
 
 HRESULT OnnxruntimeEngine::RegisterCustomRegistry(IMLOperatorRegistry* registry) {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
-  winml_adapter_api->SessionRegisterCustomRegistry(session_.get(), registry);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterCustomRegistry(session_.get(), registry),
+                          engine_factory_->UseOrtApi());
   return S_OK;
 }
 
 HRESULT OnnxruntimeEngine::EndProfiling() {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
-  winml_adapter_api->SessionEndProfiling(session_.get());
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionEndProfiling(session_.get()),
+                          engine_factory_->UseOrtApi());
   return S_OK;
 }
 
@@ -385,7 +415,8 @@ HRESULT OnnxruntimeEngine::StartProfiling() {
 
   OrtEnv* ort_env;
   engine_factory_->GetOrtEnvironment(&ort_env);
 
-  winml_adapter_api->SessionStartProfiling(ort_env, session_.get());
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionStartProfiling(ort_env, session_.get()),
+                          engine_factory_->UseOrtApi());
   return S_OK;
 }
 
@@ -393,9 +424,11 @@ HRESULT OnnxruntimeEngine::FlushContext() {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
 
-  winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider),
+                          engine_factory_->UseOrtApi());
   return S_OK;
 }
 
@@ -403,9 +436,12 @@ HRESULT OnnxruntimeEngine::TrimUploadHeap() {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
+
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderTrimUploadHeap(ort_provider),
+                          engine_factory_->UseOrtApi());
 
-  winml_adapter_api->DmlExecutionProviderTrimUploadHeap(ort_provider);
   return S_OK;
 }
 
@@ -413,9 +449,12 @@ HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
+
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(ort_provider),
+                          engine_factory_->UseOrtApi());
 
-  winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(ort_provider);
   return S_OK;
 }
 
@@ -423,7 +462,8 @@ HRESULT OnnxruntimeEngine::CopyValueAcrossDevices(IValue* src, IValue* dest) {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
 
   auto src_value = static_cast(src);
   auto dest_value = static_cast(dest);
@@ -435,7 +475,9 @@ HRESULT OnnxruntimeEngine::CopyValueAcrossDevices(IValue* src, IValue* dest) {
   auto has_null_dest = (SUCCEEDED(dest_value->IsEmpty(&is_empty)) && is_empty);
   RETURN_HR_IF(E_FAIL, has_null_dest);
 
-  winml_adapter_api->DmlCopyTensor(ort_provider, src_value->UseOrtValue(), dest_value->UseOrtValue());
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCopyTensor(ort_provider, src_value->UseOrtValue(), dest_value->UseOrtValue()),
+                          engine_factory_->UseOrtApi());
+
   return S_OK;
 }
 
@@ -443,9 +485,12 @@ HRESULT OnnxruntimeEngine::Sync() {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
+
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ExecutionProviderSync(ort_provider),
+                          engine_factory_->UseOrtApi());
 
-  winml_adapter_api->ExecutionProviderSync(ort_provider);
   return S_OK;
 }
 
@@ -466,14 +511,18 @@ HRESULT OnnxruntimeEngine::CreateTensorValue(const int64_t* shape, size_t count,
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
 
   OrtAllocator* ort_allocator;
-  winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator),
+                          engine_factory_->UseOrtApi());
+
   auto unique_allocator = UniqueOrtAllocator(ort_allocator, winml_adapter_api->FreeProviderAllocator);  // the release here should probably not return anything
 
   OrtValue* ort_value;
-  ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value),
+                          ort_api);
   auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue);
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), std::move(unique_allocator)));
   return S_OK;
@@ -500,13 +549,17 @@ HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resour
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtExecutionProvider* ort_provider;
-  winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
+                          engine_factory_->UseOrtApi());
 
   OrtMemoryInfo* dml_memory = nullptr;
-  winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory),
+                          engine_factory_->UseOrtApi());
 
   void* dml_allocator_resource;
-  winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource),
+                          engine_factory_->UseOrtApi());
+
   auto unique_dml_allocator_resource =
       DmlAllocatorResource(dml_allocator_resource,
                            [](void* ptr) {
@@ -515,14 +568,15 @@ HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resour
 
   // create the OrtValue as a tensor letting ort know that we own the data buffer
   OrtValue* ort_value;
-  ort_api->CreateTensorWithDataAsOrtValue(
-      dml_memory,
-      unique_dml_allocator_resource.get(),
-      d3d_resource->GetDesc().Width,
-      shape,
-      count,
-      ONNXTensorElementDataTypeFromTensorKind(kind),
-      &ort_value);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
+                              dml_memory,
+                              unique_dml_allocator_resource.get(),
+                              d3d_resource->GetDesc().Width,
+                              shape,
+                              count,
+                              ONNXTensorElementDataTypeFromTensorKind(kind),
+                              &ort_value),
+                          ort_api);
   auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue);
 
   Microsoft::WRL::ComPtr out_value;
@@ -544,7 +598,8 @@ HRESULT OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy(const char* c
   RETURN_IF_FAILED(CreateTensorValue(shape, count, winml::TensorKind::String, out));
 
   auto ort_value = reinterpret_cast(*out)->UseOrtValue();
-  ort_api->FillStringTensor(ort_value, reinterpret_cast(data), num_elements);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(ort_value, reinterpret_cast(data), num_elements),
+                          ort_api);
   return S_OK;
 }
 
@@ -558,17 +613,19 @@ HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalBuffer(void* data, size_
 
   // TODO: what is the difference between the device allocator and the arena allocator?
   OrtMemoryInfo* cpu_memory;
-  ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory),
+                          ort_api);
 
   OrtValue* ort_value;
-  ort_api->CreateTensorWithDataAsOrtValue(
-      cpu_memory,
-      data,
-      size_in_bytes,
-      shape,
-      count,
-      ONNXTensorElementDataTypeFromTensorKind(kind),
-      &ort_value);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue(
+                              cpu_memory,
+                              data,
+                              size_in_bytes,
+                              shape,
+                              count,
+                              ONNXTensorElementDataTypeFromTensorKind(kind),
+                              &ort_value),
+                          ort_api);
   auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue);
 
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr)));
@@ -646,10 +703,12 @@ template struct FillMapTensors {
   static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
     AbiTypeInfo::OrtType* keys_mutable_data;
-    ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast(&keys_mutable_data));
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast(&keys_mutable_data)),
+                            ort_api);
 
     AbiTypeInfo::OrtType* values_mutable_data;
-    ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast(&values_mutable_data));
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast(&values_mutable_data)),
+                            ort_api);
 
     auto map = CastToWinrtMap(map_insp);
     size_t index = 0;
@@ -666,7 +725,8 @@ template struct FillMapTensors {
   static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
     AbiTypeInfo::OrtType* values_mutable_data;
-    ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast(&values_mutable_data));
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast(&values_mutable_data)),
+                            ort_api);
 
     auto map = CastToWinrtMap(map_insp);
     size_t index = 0;
@@ -684,7 +744,8 @@ struct FillMapTensors {
                    std::back_inserter(raw_values),
                    [&](auto& str) { return str.c_str(); });
 
-    ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size());
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size()),
+                            ort_api);
 
     return S_OK;
   }
@@ -694,7 +755,8 @@ template struct FillMapTensors {
   static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
     AbiTypeInfo::OrtType* keys_mutable_data;
-    ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast(&keys_mutable_data));
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast(&keys_mutable_data)),
+                            ort_api);
 
     auto map = CastToWinrtMap(map_insp);
     size_t index = 0;
@@ -712,7 +774,8 @@ struct FillMapTensors {
                    std::back_inserter(raw_values),
                    [&](auto& str) { return str.c_str(); });
 
-    ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size());
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size()),
+                            ort_api);
     return S_OK;
   }
 };
@@ -744,8 +807,10 @@ struct FillMapTensors {
                    std::back_inserter(raw_values),
                    [&](auto& str) { return str.c_str(); });
 
-    ort_api->FillStringTensor(keys_ort_value, raw_keys.data(), raw_keys.size());
-    ort_api->FillStringTensor(values_ort_value, raw_values.data(), raw_values.size());
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_keys.data(), raw_keys.size()),
+                            ort_api);
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(values_ort_value, raw_values.data(), raw_values.size()),
+                            ort_api);
     return S_OK;
   }
 };
@@ -770,7 +835,8 @@ HRESULT CreateMapValue(OnnxruntimeEngine* engine, IInspectable* map_insp, winml:
   OrtValue* inputs[2] = {keys_ort_value, values_ort_value};
 
   OrtValue* map_value;
-  ort_api->CreateValue(inputs, 2, ONNXType::ONNX_TYPE_MAP, &map_value);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateValue(inputs, 2, ONNXType::ONNX_TYPE_MAP, &map_value),
+                          ort_api);
   auto unique_map_ort_value = UniqueOrtValue(map_value, ort_api->ReleaseValue);
 
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, engine, std::move(unique_map_ort_value), UniqueOrtAllocator(nullptr, nullptr)));
@@ -947,7 +1013,8 @@ HRESULT OnnxruntimeEngine::Run(const char** input_names, IValue** inputs, size_t
   auto ort_api = engine_factory_->UseOrtApi();
 
   OrtRunOptions* run_options;
-  ort_api->CreateRunOptions(&run_options);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateRunOptions(&run_options),
+                          ort_api);
   auto unique_run_options = UniqueOrtRunOptions(run_options, ort_api->ReleaseRunOptions);
 
   std::vector input_ort_values;
@@ -970,14 +1037,15 @@ HRESULT OnnxruntimeEngine::Run(const char** input_names, IValue** inputs, size_t
                    return output_value->UseOrtValue();
                  });
 
-  ort_api->Run(session_.get(),
-               unique_run_options.get(),
-               input_names,
-               input_ort_values.data(),
-               num_inputs,
-               output_names,
-               num_outputs,
-               output_ort_values.data());
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->Run(session_.get(),
+                                       unique_run_options.get(),
+                                       input_names,
+                                       input_ort_values.data(),
+                                       num_inputs,
+                                       output_names,
+                                       num_outputs,
+                                       output_ort_values.data()),
+                          ort_api);
 
   for (size_t index = 0; index < num_outputs; index++) {
     auto output_value = static_cast(outputs[index]);
@@ -1033,18 +1101,21 @@ HRESULT OnnxruntimeEngine::FillFromMapValue(IInspectable* map, winml::TensorKind
   auto ort_map_value = onnxruntime_map_value->UseOrtValue();
 
   OrtAllocator* ort_allocator;
-  ort_api->GetAllocatorWithDefaultOptions(&ort_allocator);  // This should not be freed as this owned by ort
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&ort_allocator),
+                          ort_api);  // This should not be freed as this owned by ort
 
   // get the keys
   OrtValue* keys_ort_value = nullptr;
-  ort_api->GetValue(ort_map_value, 0, ort_allocator, &keys_ort_value);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 0, ort_allocator, &keys_ort_value),
+                          ort_api);
   auto unique_keys_value = UniqueOrtValue(keys_ort_value, ort_api->ReleaseValue);
   winrt::com_ptr keys_value;
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(keys_value.put(), this, std::move(unique_keys_value), UniqueOrtAllocator(nullptr, nullptr)));
 
   // get the keys
   OrtValue* values_ort_value = nullptr;
-  ort_api->GetValue(ort_map_value, 1, ort_allocator, &values_ort_value);
+  RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 1, ort_allocator, &values_ort_value),
+                          ort_api);
   auto unique_values_value = UniqueOrtValue(values_ort_value, ort_api->ReleaseValue);
   winrt::com_ptr values_value;
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(values_value.put(), this, std::move(unique_values_value), UniqueOrtAllocator(nullptr, nullptr)));
@@ -1074,9 +1145,8 @@ HRESULT OnnxruntimeEngineFactory::RuntimeClassInitialize() {
 
 STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) {
   OrtModel* ort_model = nullptr;
-  if (auto status = winml_adapter_api_->CreateModelFromPath(model_path, len, &ort_model)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateModelFromPath(model_path, len, &ort_model),
+                          ort_api_);
 
   auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel);
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(model)));
@@ -1085,9 +1155,8 @@ STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ const char* model_path,
 
 STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) {
   OrtModel* ort_model = nullptr;
-  if (auto status = winml_adapter_api_->CreateModelFromData(data, size, &ort_model)) {
-    return E_INVALIDARG;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateModelFromData(data, size, &ort_model),
+                          ort_api_);
 
   auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel);
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(model)));
@@ -1120,7 +1189,8 @@ HRESULT OnnxruntimeEngineFactory::EnableDebugOutput(bool is_enabled) {
 }
 
 HRESULT OnnxruntimeEngineFactory::CreateCustomRegistry(IMLOperatorRegistry** registry) {
-  winml_adapter_api_->CreateCustomRegistry(registry);
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateCustomRegistry(registry),
+                          ort_api_);
   return S_OK;
 }
 
diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp
index 2a268964b42c3..f70624e4cc30a 100644
--- a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp
+++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp
@@ -8,6 +8,7 @@
 #include "OnnxruntimeDmlSessionBuilder.h"
 #endif
 
+#include "OnnxruntimeErrors.h"
 using namespace WinML;
 
 HRESULT OnnxruntimeEngineBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) {
@@ -31,10 +32,11 @@ STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(Windows::AI::MachineLearning
   OrtSessionOptions* ort_options;
   RETURN_IF_FAILED(onnxruntime_session_builder->CreateSessionOptions(&ort_options));
   auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions);
-  
+
   if (batch_size_override_.has_value()) {
     constexpr const char* DATA_BATCH = "DATA_BATCH";
-    ort_api->AddFreeDimensionOverride(session_options.get(), DATA_BATCH, batch_size_override_.value());
+    RETURN_HR_IF_NOT_OK_MSG(ort_api->AddFreeDimensionOverride(session_options.get(), DATA_BATCH, batch_size_override_.value()),
+                            ort_api);
   }
 
   OrtSession* ort_session = nullptr;
@@ -43,7 +45,7 @@ STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(Windows::AI::MachineLearning
 
   Microsoft::WRL::ComPtr onnxruntime_engine;
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine,
-                                                     engine_factory_.Get(), std::move(session), onnxruntime_session_builder.Get()));
+                                                                        engine_factory_.Get(), std::move(session), onnxruntime_session_builder.Get()));
   RETURN_IF_FAILED(onnxruntime_engine.CopyTo(out));
   return S_OK;
 }
diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp
index 80dcca6c16208..700e437b7fe67 100644
--- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp
+++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp
@@ -3,6 +3,7 @@
 
 #include "pch.h"
 #include "OnnxruntimeEnvironment.h"
+#include "OnnxruntimeErrors.h"
 #include "core/platform/windows/TraceLoggingConfig.h"
 #include 
 
@@ -122,9 +123,8 @@ static void WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_
 
 OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_(nullptr, nullptr) {
   OrtEnv* ort_env = nullptr;
-  if (auto status = ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env)) {
-    throw;
-  }
+  THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env),
+                      ort_api);
   ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv);
 
   // Configure the environment with the winml logger
@@ -132,11 +132,9 @@ OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_
   auto status = winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env_.get(), &WinmlOrtLoggingCallback, &WinmlOrtProfileEventCallback, nullptr, OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env);
-  if (status) {
-    throw;
-  }
+  THROW_IF_NOT_OK_MSG(status, ort_api);
 
-  winml_adapter_api->OverrideSchema();
+  THROW_IF_NOT_OK_MSG(winml_adapter_api->OverrideSchema(), ort_api);
 }
 
 HRESULT OnnxruntimeEnvironment::GetOrtEnvironment(_Out_ OrtEnv** ort_env) {
diff --git a/winml/lib/Api.Ort/OnnxruntimeErrors.h b/winml/lib/Api.Ort/OnnxruntimeErrors.h
new file mode 100644
index 0000000000000..3f9fd88b783d6
--- /dev/null
+++ b/winml/lib/Api.Ort/OnnxruntimeErrors.h
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+#include "pch.h"
+#include "core/providers/winml/winml_provider_factory.h"
+
+#ifdef _WIN32
+inline HRESULT OrtErrorCodeToHRESULT(OrtErrorCode status) noexcept {
+  switch (status) {
+    case OrtErrorCode::ORT_OK:
+      return S_OK;
+    case OrtErrorCode::ORT_FAIL:
+      return E_FAIL;
+    case OrtErrorCode::ORT_INVALID_ARGUMENT:
+      return E_INVALIDARG;
+    case OrtErrorCode::ORT_NO_SUCHFILE:
+      return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
+    case OrtErrorCode::ORT_NO_MODEL:
+      return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
+    case OrtErrorCode::ORT_ENGINE_ERROR:
+      return E_FAIL;
+    case OrtErrorCode::ORT_RUNTIME_EXCEPTION:
+      return E_FAIL;
+    case OrtErrorCode::ORT_INVALID_PROTOBUF:
+      return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
+    case OrtErrorCode::ORT_MODEL_LOADED:
+      return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
+    case OrtErrorCode::ORT_NOT_IMPLEMENTED:
+      return E_NOTIMPL;
+    case OrtErrorCode::ORT_INVALID_GRAPH:
+      return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
+    case OrtErrorCode::ORT_EP_FAIL:
+      return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
+    default:
+      return E_FAIL;
+  }
+}
+#endif
+
+#define RETURN_HR_IF_NOT_OK_MSG(status, ort_api) \
+  do { \
+    auto _status = status; \
+    if (_status) { \
+      auto error_code = ort_api->GetErrorCode(_status); \
+      auto error_message = ort_api->GetErrorMessage(_status); \
+      HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \
+      telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \
+      RETURN_HR_MSG(hresult, \
+                    error_message); \
+    } \
+  } while (0)
+
+#define THROW_IF_NOT_OK_MSG(status, ort_api) \
+  do { \
+    auto _status = status; \
+    if (_status) { \
+      auto error_code = ort_api->GetErrorCode(_status); \
+      auto error_message = ort_api->GetErrorMessage(_status); \
+      HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \
+      telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \
+      winrt::hstring errorMessage(WinML::Strings::HStringFromUTF8(error_message)); \
+      throw winrt::hresult_error(hresult, errorMessage); \
+    } \
+  } while (0)
diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.cpp b/winml/lib/Api.Ort/OnnxruntimeModel.cpp
index e112db66112d9..bc782fbd17343 100644
--- a/winml/lib/Api.Ort/OnnxruntimeModel.cpp
+++ b/winml/lib/Api.Ort/OnnxruntimeModel.cpp
@@ -8,6 +8,7 @@
 
 #include "OnnxruntimeDescriptorConverter.h"
 #include "OnnxruntimeEngine.h"
+#include "OnnxruntimeErrors.h"
 
 using namespace Windows::AI::MachineLearning;
 
@@ -25,25 +26,22 @@ HRESULT CreateFeatureDescriptors(
     std::vector& descriptors) {
   const auto ort_api = engine_factory->UseOrtApi();
   size_t count;
-  if (auto status = feature_helpers->GetCount(ort_model, &count)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetCount(ort_model, &count),
+                          engine_factory->UseOrtApi());
 
   for (size_t i = 0; i < count; i++) {
     OnnxruntimeValueInfoWrapper descriptor;
-    if (auto status = feature_helpers->GetName(ort_model, i, &descriptor.name_, &descriptor.name_length_)) {
-      return E_FAIL;
-    }
-    if (auto status = feature_helpers->GetDescription(ort_model, i, &descriptor.description_, &descriptor.description_length_)) {
-      return E_FAIL;
-    }
+    RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetName(ort_model, i, &descriptor.name_, &descriptor.name_length_),
+                            engine_factory->UseOrtApi());
+    RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetDescription(ort_model, i, &descriptor.description_, &descriptor.description_length_),
+                            engine_factory->UseOrtApi());
 
     OrtTypeInfo* type_info;
-    if (auto status = feature_helpers->GetTypeInfo(ort_model, i, &type_info)) {
-      return E_FAIL;
-    }
+    RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetTypeInfo(ort_model, i, &type_info),
+                            engine_factory->UseOrtApi());
+
     descriptor.type_info_ = UniqueOrtTypeInfo(type_info, ort_api->ReleaseTypeInfo);
-    
+
     descriptors.push_back(std::move(descriptor));
   }
   return S_OK;
@@ -56,18 +54,16 @@ HRESULT ModelInfo::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_facto
 
   // Get Metadata
   size_t count;
-  if (auto status = winml_adapter_api->ModelGetMetadataCount(ort_model, &count)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetMetadataCount(ort_model, &count),
+                          engine_factory->UseOrtApi());
 
   const char* metadata_key;
   size_t metadata_key_len;
   const char* metadata_value;
   size_t metadata_value_len;
   for (size_t i = 0; i < count; i++) {
-    if (auto status = winml_adapter_api->ModelGetMetadata(ort_model, i, &metadata_key, &metadata_key_len, &metadata_value, &metadata_value_len)) {
-      return E_FAIL;
-    }
+    RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetMetadata(ort_model, i, &metadata_key, &metadata_key_len, &metadata_value, &metadata_value_len),
+                            engine_factory->UseOrtApi());
 
     model_metadata_.insert_or_assign(
         std::string(metadata_key, metadata_key_len),
@@ -101,29 +97,24 @@ HRESULT ModelInfo::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_facto
 
   const char* out;
   size_t len;
-  if (auto status = winml_adapter_api->ModelGetAuthor(ort_model, &out, &len)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetAuthor(ort_model, &out, &len),
+                          engine_factory->UseOrtApi());
   author_ = std::string(out, len);
 
-  if (auto status = winml_adapter_api->ModelGetName(ort_model, &out, &len)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetName(ort_model, &out, &len),
+                          engine_factory->UseOrtApi());
   name_ = std::string(out, len);
 
-  if (auto status = winml_adapter_api->ModelGetDomain(ort_model, &out, &len)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetDomain(ort_model, &out, &len),
+                          engine_factory->UseOrtApi());
   domain_ = std::string(out, len);
 
-  if (auto status = winml_adapter_api->ModelGetDescription(ort_model, &out, &len)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetDescription(ort_model, &out, &len),
+                          engine_factory->UseOrtApi());
   description_ = std::string(out, len);
 
-  if (auto status = winml_adapter_api->ModelGetVersion(ort_model, &version_)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetVersion(ort_model, &version_),
+                          engine_factory->UseOrtApi());
 
   return S_OK;
 }
@@ -205,9 +196,8 @@ STDMETHODIMP OnnruntimeModel::GetModelInfo(IModelInfo** info) {
 
 STDMETHODIMP OnnruntimeModel::ModelEnsureNoFloat16() {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
-  if (auto status = winml_adapter_api->ModelEnsureNoFloat16(ort_model_.get())) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelEnsureNoFloat16(ort_model_.get()),
+                          engine_factory_->UseOrtApi());
   return S_OK;
 }
 
@@ -215,9 +205,8 @@ STDMETHODIMP OnnruntimeModel::CloneModel(IModel** copy) {
   auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
 
   OrtModel* ort_model_copy;
-  if (auto status = winml_adapter_api->CloneModel(ort_model_.get(), &ort_model_copy)) {
-    return E_FAIL;
-  }
+  RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CloneModel(ort_model_.get(), &ort_model_copy),
+                          engine_factory_->UseOrtApi());
 
   auto model = UniqueOrtModel(ort_model_copy, winml_adapter_api->ReleaseModel);
   RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(copy, engine_factory_.Get(), std::move(model)));
@@ -225,7 +214,6 @@ STDMETHODIMP OnnruntimeModel::CloneModel(IModel** copy) {
   return S_OK;
 }
 
-
 STDMETHODIMP OnnruntimeModel::DetachOrtModel(OrtModel** model) {
   *model = ort_model_.release();
   return S_OK;