diff --git a/cmake/winml.cmake b/cmake/winml.cmake index e7a4c6374efa5..dba8d02b12805 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -8,12 +8,13 @@ include(winml_cppwinrt.cmake) # get the current nuget sdk kit directory get_sdk(sdk_folder sdk_version) set(target_folder ONNXRuntime/winml) +set(winml_adapter_dir ${REPO_ROOT}/winml/adapter) set(winml_api_root ${REPO_ROOT}/winml/api) set(winml_dll_dir ${REPO_ROOT}/winml/dll) set(winml_lib_dir ${REPO_ROOT}/winml/lib) set(winml_lib_api_dir ${REPO_ROOT}/winml/lib/api) -set(winml_adapter_dir ${REPO_ROOT}/winml/adapter) set(winml_lib_api_image_dir ${REPO_ROOT}/winml/lib/api.image) +set(winml_lib_api_ort_dir ${REPO_ROOT}/winml/lib/api.ort) set(winml_lib_common_dir ${REPO_ROOT}/winml/lib/common) set(winml_lib_telemetry_dir ${REPO_ROOT}/winml/lib/telemetry) @@ -116,32 +117,102 @@ set_target_properties(winml_lib_telemetry # Link libraries target_link_libraries(winml_lib_telemetry PRIVATE wil) +########################### +# Add winml_lib_ort +########################### + +list(APPEND winml_lib_api_ort_files + ${winml_lib_api_ort_dir}/inc/OnnxruntimeProvider.h + ${winml_lib_api_ort_dir}/OnnxruntimeCpuSessionBuilder.h + ${winml_lib_api_ort_dir}/OnnxruntimeCpuSessionBuilder.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeDescriptorConverter.h + ${winml_lib_api_ort_dir}/OnnxruntimeDescriptorConverter.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeEngine.h + ${winml_lib_api_ort_dir}/OnnxruntimeEngine.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeEngineBuilder.h + ${winml_lib_api_ort_dir}/OnnxruntimeEngineBuilder.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeEnvironment.h + ${winml_lib_api_ort_dir}/OnnxruntimeEnvironment.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeModel.h + ${winml_lib_api_ort_dir}/OnnxruntimeModel.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeSessionBuilder.h + ${winml_lib_api_ort_dir}/pch.h + ) + +if (onnxruntime_USE_DML) + list(APPEND winml_lib_api_ort_files + ${winml_lib_api_ort_dir}/OnnxruntimeDmlSessionBuilder.h + 
${winml_lib_api_ort_dir}/OnnxruntimeDmlSessionBuilder.cpp + ) +endif(onnxruntime_USE_DML) + +# Add static library that will be archived/linked for both static/dynamic library +add_library(winml_lib_ort STATIC ${winml_lib_api_ort_files}) + +# Compiler options +target_compile_features(winml_lib_ort PRIVATE cxx_std_17) +target_compile_options(winml_lib_ort PRIVATE /GR- /await /wd4238) + +# Compiler definitions +target_compile_definitions(winml_lib_ort PRIVATE PLATFORM_WINDOWS) +target_compile_definitions(winml_lib_ort PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators + +# Specify the usage of a precompiled header +target_precompiled_header(winml_lib_ort pch.h) + +# Includes +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers + +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + +target_include_directories(winml_lib_ort PRIVATE ${REPO_ROOT}/winml) +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_dir}) # needed for generated headers +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_core_dir}) +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_ort_dir}) +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_common_dir}/inc) +target_include_directories(winml_lib_ort PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}) +target_include_directories(winml_lib_ort PRIVATE ${ONNXRUNTIME_ROOT}) + +set_target_properties(winml_lib_ort + PROPERTIES + FOLDER + ${target_folder}) + +# Add deps +add_dependencies(winml_lib_ort winml_sdk_cppwinrt) +add_dependencies(winml_lib_ort winml_api) +add_dependencies(winml_lib_ort 
winml_api_native) +add_dependencies(winml_lib_ort winml_api_native_internal) + +# Link libraries +target_link_libraries(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/packages/DirectML.0.0.1/build/DirectML.targets) +target_link_libraries(winml_lib_ort PRIVATE wil) + + ########################### # Add winml_adapter ########################### list(APPEND winml_adapter_files - ${winml_adapter_dir}/CpuOrtSessionBuilder.cpp - ${winml_adapter_dir}/CpuOrtSessionBuilder.h - ${winml_adapter_dir}/CustomRegistryHelper.h - ${winml_adapter_dir}/FeatureDescriptorFactory.cpp - ${winml_adapter_dir}/FeatureDescriptorFactory.h - ${winml_adapter_dir}/LotusEnvironment.cpp - ${winml_adapter_dir}/LotusEnvironment.h ${winml_adapter_dir}/pch.h - ${winml_adapter_dir}/WinMLAdapter.cpp - ${winml_adapter_dir}/WinMLAdapter.h - ${winml_adapter_dir}/ZeroCopyInputStreamWrapper.cpp - ${winml_adapter_dir}/ZeroCopyInputStreamWrapper.h + ${winml_adapter_dir}/winml_adapter_apis.h + ${winml_adapter_dir}/winml_adapter_c_api.h + ${winml_adapter_dir}/winml_adapter_c_api.cpp + ${winml_adapter_dir}/winml_adapter_dml.cpp + ${winml_adapter_dir}/winml_adapter_environment.cpp + ${winml_adapter_dir}/winml_adapter_execution_provider.cpp + ${winml_adapter_dir}/winml_adapter_model.cpp + ${winml_adapter_dir}/winml_adapter_model.h + ${winml_adapter_dir}/winml_adapter_session.cpp ) - + if (onnxruntime_USE_DML) list(APPEND winml_adapter_files - ${winml_adapter_dir}/AbiCustomRegistryImpl.cpp - ${winml_adapter_dir}/AbiCustomRegistryImpl.h - ${winml_adapter_dir}/DmlOrtSessionBuilder.cpp - ${winml_adapter_dir}/DmlOrtSessionBuilder.h - ) + ${winml_adapter_dir}/abi_custom_registry_impl.cpp + ${winml_adapter_dir}/abi_custom_registry_impl.h + ) endif(onnxruntime_USE_DML) add_library(winml_adapter ${winml_adapter_files}) @@ -329,6 +400,7 @@ target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_dir}) target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_dir}/pch) 
target_include_directories(winml_lib_api PRIVATE ${winml_adapter_dir}) target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_image_dir}/inc) +target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_ort_dir}/inc) target_include_directories(winml_lib_api PRIVATE ${winml_lib_telemetry_dir}/inc) target_include_directories(winml_lib_api PRIVATE ${winml_lib_common_dir}/inc) @@ -370,6 +442,19 @@ endif(onnxruntime_USE_DML) ########################### add_library(winml_lib_common STATIC + ${winml_lib_common_dir}/inc/common.h + ${winml_lib_common_dir}/inc/CommonDeviceHelpers.h + ${winml_lib_common_dir}/inc/cppwinrt_onnx.h + ${winml_lib_common_dir}/inc/dx.h + ${winml_lib_common_dir}/inc/errors.h + ${winml_lib_common_dir}/inc/iengine.h + ${winml_lib_common_dir}/inc/NamespaceAliases.h + ${winml_lib_common_dir}/inc/onnx.h + ${winml_lib_common_dir}/inc/PheonixSingleton.h + ${winml_lib_common_dir}/inc/StringHelpers.h + ${winml_lib_common_dir}/inc/WinMLTelemetryHelper.h + ${winml_lib_common_dir}/inc/WinML_Lock.h + ${winml_lib_common_dir}/inc/winrt_headers.h ${winml_lib_common_dir}/CommonDeviceHelpers.cpp ) @@ -448,6 +533,7 @@ target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/s target_include_directories(winml_dll PRIVATE ${winml_dll_dir}) target_include_directories(winml_dll PRIVATE ${winml_lib_api_dir}) target_include_directories(winml_dll PRIVATE ${winml_lib_api_dir}/impl) +target_include_directories(winml_dll PRIVATE ${winml_lib_api_ort_dir}/inc) target_include_directories(winml_dll PRIVATE ${winml_adapter_dir}) target_include_directories(winml_dll PRIVATE ${winml_lib_api_image_dir}/inc) target_include_directories(winml_dll PRIVATE ${winml_lib_telemetry_dir}/inc) @@ -514,6 +600,7 @@ target_link_libraries(winml_dll PRIVATE re2) target_link_libraries(winml_dll PRIVATE wil) target_link_libraries(winml_dll PRIVATE winml_lib_api) target_link_libraries(winml_dll PRIVATE winml_lib_image) +target_link_libraries(winml_dll PRIVATE 
winml_lib_ort) target_link_libraries(winml_dll PRIVATE winml_lib_telemetry) target_link_libraries(winml_dll PRIVATE delayimp.lib) target_link_libraries(winml_dll PRIVATE ${DBGHELP}) diff --git a/include/onnxruntime/core/providers/winml/winml_provider_factory.h b/include/onnxruntime/core/providers/winml/winml_provider_factory.h index b4d4a754d2460..b08b42e310e41 100644 --- a/include/onnxruntime/core/providers/winml/winml_provider_factory.h +++ b/include/onnxruntime/core/providers/winml/winml_provider_factory.h @@ -1,14 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "onnxruntime_c_api.h" -#ifdef __cplusplus -#include -using namespace Windows::AI::MachineLearning::Adapter; -#else -struct IWinMLAdapter; -typedef struct IWinMLAdapter IWinMLAdapter; -#endif +#include "onnxruntime_c_api.h" -ORT_EXPORT STDAPI OrtGetWinMLAdapter(IWinMLAdapter** adapter); +struct WinmlAdapterApi; +typedef struct WinmlAdapterApi WinmlAdapterApi; +ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION; \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 176b988ad0f3b..ee05614229f05 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -156,6 +156,8 @@ ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); ORT_RUNTIME_CLASS(SessionOptions); ORT_RUNTIME_CLASS(CustomOpDomain); +ORT_RUNTIME_CLASS(MapTypeInfo); +ORT_RUNTIME_CLASS(SequenceTypeInfo); // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. 
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c466b0cb8a79c..a97a5d413f904 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -75,17 +75,11 @@ ORT_DEFINE_RELEASE(Value); // This is used internally by the C++ API. This is the common base class used by the wrapper objects. template struct Base { - Base() { - p_ = nullptr; - } + Base() = default; Base(T* p) : p_{p} { if (!p) throw Ort::Exception("Allocation failure", ORT_FAIL); } - ~Base() { - if (p_ != nullptr) { - OrtRelease(p_); - } - } + ~Base() { OrtRelease(p_); } operator T*() { return p_; } operator const T*() const { return p_; } @@ -96,19 +90,12 @@ struct Base { return p; } - T** put() noexcept { - assert(p_ == nullptr); - return &p_; - } - protected: Base(const Base&) = delete; Base& operator=(const Base&) = delete; Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } void operator=(Base&& v) noexcept { - if (p_ != nullptr) { - OrtRelease(p_); - } + OrtRelease(p_); p_ = v.p_; v.p_ = nullptr; } @@ -275,7 +262,6 @@ struct Value : Base { size_t GetStringTensorDataLength() const; void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; - std::vector GetStrings(); template T* GetTensorMutableData(); @@ -306,9 +292,6 @@ struct MemoryInfo : Base { MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); explicit MemoryInfo(OrtMemoryInfo* p) : Base{p} {} - - const char* Name() const; - OrtMemType MemType() const; }; // @@ -371,4 +354,4 @@ struct CustomOpBase : OrtCustomOp { } // namespace Ort -#include "onnxruntime_cxx_inline.h" +#include "onnxruntime_cxx_inline.h" \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index be2fe6bf9e2c2..f6fb350171f01 100644 --- 
a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -76,18 +76,6 @@ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, O ThrowOnError(Global::api_.CreateMemoryInfo(name, type, id, mem_type, &p_)); } -inline const char* MemoryInfo::Name() const { - const char* out = nullptr; - ThrowOnError(Global::api_.MemoryInfoGetName(p_, &out)); - return out; -} - -inline OrtMemType MemoryInfo::MemType() const { - OrtMemType out; - ThrowOnError(Global::api_.MemoryInfoGetMemType(p_, &out)); - return out; -} - inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) { ThrowOnError(Global::api_.CreateEnv(default_warning_level, logid, &p_)); } @@ -357,21 +345,6 @@ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_ return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); } -template <> -inline Value Value::CreateTensor(const OrtMemoryInfo*, std::string* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { - // convert the array of std::string to an array of const char * - std::vector string_vector; - for (size_t i = 0; i < p_data_element_count; ++i) { - string_vector.push_back(p_data[i].c_str()); - } - // now make an empty tensor using the default allocator (strings have to make a copy) - AllocatorWithDefaultOptions allocator; - auto tensor = Value::CreateTensor(static_cast(allocator), shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); - // now fill the string data - ThrowOnError(GetApi().FillStringTensor(tensor, string_vector.data(), string_vector.size())); - return tensor; -} - inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { OrtValue* out; @@ -444,33 +417,6 @@ inline void Value::GetStringTensorContent(void* buffer, size_t 
buffer_length, si ThrowOnError(Global::api_.GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count)); } -inline std::vector Value::GetStrings() { - std::vector out; - // make sure this is an array of strings - auto shape = this->GetTensorTypeAndShapeInfo().GetShape(); - // there needs to be only one dimension - if (shape.size() != 1) throw Ort::Exception("shape.size() != 1", ORT_INVALID_ARGUMENT); - // make a big buffer to hold all the string data - size_t buflen = this->GetStringTensorDataLength(); - std::vector buf(buflen); - std::vector offsets(shape[0]); - this->GetStringTensorContent(buf.data(), buf.size(), offsets.data(), offsets.size()); - // now go build all the strings - for (auto i = 0; i < shape[0]; ++i) { - std::string str; - size_t strlen = 0; - // are we on the last one? - if (i == (shape[0] - 1ll)) { - strlen = buflen - offsets[i]; - } else { - strlen = offsets[i + 1ll] - offsets[i]; - } - str.append(reinterpret_cast(buf.data() + offsets[i]), strlen); - out.push_back(str); - } - return out; -} - template T* Value::GetTensorMutableData() { T* out; @@ -607,4 +553,4 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, return out; } -} // namespace Ort +} // namespace Ort \ No newline at end of file diff --git a/onnxruntime/core/framework/allocatormgr.cc b/onnxruntime/core/framework/allocatormgr.cc index f4258d5a6a889..a38d89a9e2bb8 100644 --- a/onnxruntime/core/framework/allocatormgr.cc +++ b/onnxruntime/core/framework/allocatormgr.cc @@ -29,9 +29,4 @@ AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id return AllocatorPtr(std::move(device_allocator)); } -DeviceAllocatorRegistry& DeviceAllocatorRegistry::Instance() { - static DeviceAllocatorRegistry s_instance; - return s_instance; -} - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocatormgr.h b/onnxruntime/core/framework/allocatormgr.h index 3985fd4b66a98..aa346fc52f575 100644 --- 
a/onnxruntime/core/framework/allocatormgr.h +++ b/onnxruntime/core/framework/allocatormgr.h @@ -18,25 +18,4 @@ struct DeviceAllocatorRegistrationInfo { AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id = 0); -class DeviceAllocatorRegistry { - public: - void RegisterDeviceAllocator(std::string&& name, DeviceAllocatorFactory factory, size_t max_mem, - OrtMemType mem_type = OrtMemTypeDefault) { - DeviceAllocatorRegistrationInfo info({mem_type, factory, max_mem}); - device_allocator_registrations_.emplace(std::move(name), std::move(info)); - } - - const std::map& AllRegistrations() const { - return device_allocator_registrations_; - } - - static DeviceAllocatorRegistry& Instance(); - - private: - DeviceAllocatorRegistry() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeviceAllocatorRegistry); - - std::map device_allocator_registrations_; -}; - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc new file mode 100644 index 0000000000000..107cdbbed10c2 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "core/framework/onnxruntime_map_type_info.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type) noexcept : map_key_type_(map_key_type), map_value_type_(map_value_type, &OrtApis::ReleaseTypeInfo) { +} + +static ONNXTensorElementDataType +ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { + using TensorType = ONNX_NAMESPACE::TensorProto_DataType; + switch (data_type) { + case TensorType::TensorProto_DataType_BOOL: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; } + case TensorType::TensorProto_DataType_STRING: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; } // maps to c++ type std::string + case TensorType::TensorProto_DataType_FLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } // maps to c type float + case TensorType::TensorProto_DataType_FLOAT: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; } + case TensorType::TensorProto_DataType_DOUBLE: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; } // maps to c type double + case TensorType::TensorProto_DataType_INT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; } // maps to c type int8_t + case TensorType::TensorProto_DataType_INT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; } // maps to c type int16_t + case TensorType::TensorProto_DataType_INT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } // maps to c type int32_t + case TensorType::TensorProto_DataType_INT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; } // maps to c type int64_t + case TensorType::TensorProto_DataType_UINT8: { return 
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; } // maps to c type uint8_t + case TensorType::TensorProto_DataType_UINT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; } // maps to c type uint16_t + case TensorType::TensorProto_DataType_UINT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; } // maps to c type uint32_t + case TensorType::TensorProto_DataType_UINT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; } // maps to c type uint64_t + case TensorType::TensorProto_DataType_COMPLEX64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; } // complex with float32 real and imaginary components + case TensorType::TensorProto_DataType_COMPLEX128: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; } // complex with float64 real and imaginary components + case TensorType::TensorProto_DataType_BFLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; } // Non-IEEE floating-point format based on IEEE754 single-precision + default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } + } +} + +OrtStatus* OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtMapTypeInfo** out) { + auto value_case = type_proto->value_case(); + if (value_case != ONNX_NAMESPACE::TypeProto::kMapType) + { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type map!");; + } + + // Get the key type of the map + auto type_proto_map = type_proto->map_type(); + auto map_key_type = ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type())); + + // Get the value type of the map + OrtTypeInfo* map_value_type_info = nullptr; + if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_map.value_type(), &map_value_type_info)) + { + return status; + } + + *out = new OrtMapTypeInfo(map_key_type, map_value_type_info); + 
return nullptr; +} + +OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) { + OrtTypeInfo* map_value_type_copy = nullptr; + if (auto status = map_value_type_->Clone(&map_value_type_copy)) + { + return status; + } + *out = new OrtMapTypeInfo(map_key_type_, map_value_type_copy); + return nullptr; +} + +// OrtMapTypeInfo Accessors +ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, const OrtMapTypeInfo* map_type_info, enum ONNXTensorElementDataType* out) { + API_IMPL_BEGIN + *out = map_type_info->map_key_type_; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, const OrtMapTypeInfo* map_type_info, OrtTypeInfo** out) { + API_IMPL_BEGIN + return map_type_info->map_value_type_->Clone(out); + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseMapTypeInfo, OrtMapTypeInfo* ptr) { + delete ptr; +} \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.h b/onnxruntime/core/framework/onnxruntime_map_type_info.h new file mode 100644 index 0000000000000..2d9297c8cb2d4 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include "onnxruntime_c_api.h" + +#include + +namespace ONNX_NAMESPACE { +class TypeProto; +} + +struct OrtMapTypeInfo { + public: + ONNXTensorElementDataType map_key_type_ = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; + std::unique_ptr map_value_type_; + + static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtMapTypeInfo** out); + + OrtStatus* Clone(OrtMapTypeInfo** out); + + private: + OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type)noexcept; + OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete; + OrtMapTypeInfo& operator=(const OrtMapTypeInfo& other) = delete; + +}; diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc new file mode 100644 index 0000000000000..a5ee0c9a63bb1 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "core/framework/onnxruntime_sequence_type_info.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +OrtSequenceTypeInfo::OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept : + sequence_key_type_(sequence_key_type, &OrtApis::ReleaseTypeInfo) { +} + +OrtStatus* OrtSequenceTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtSequenceTypeInfo** out) { + auto value_case = type_proto->value_case(); + if (value_case != ONNX_NAMESPACE::TypeProto::kSequenceType) + { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type sequence!");; + } + + auto type_proto_sequence = type_proto->sequence_type(); + OrtTypeInfo* sequence_key_type_info = nullptr; + if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_sequence.elem_type(), &sequence_key_type_info)) + { + return status; + } + + *out = new OrtSequenceTypeInfo(sequence_key_type_info); + return nullptr; +} + +OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) { + OrtTypeInfo* sequence_key_type_copy = nullptr; + if (auto status = sequence_key_type_->Clone(&sequence_key_type_copy)) + { + return status; + } + *out = new OrtSequenceTypeInfo(sequence_key_type_copy); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, const OrtSequenceTypeInfo* sequence_type_info, OrtTypeInfo** out) { + API_IMPL_BEGIN + return sequence_type_info->sequence_key_type_->Clone(out); + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, OrtSequenceTypeInfo* ptr) { + delete ptr; +} \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h new file mode 100644 index 0000000000000..6efa55c8de763 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h @@ -0,0 +1,25 @@ +// Copyright (c) 
Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "onnxruntime_c_api.h" + +#include + +namespace ONNX_NAMESPACE { +class TypeProto; +} + +struct OrtSequenceTypeInfo { + public: + std::unique_ptr sequence_key_type_; + + OrtStatus* Clone(OrtSequenceTypeInfo** out); + + static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtSequenceTypeInfo** out); + + private: + OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept; + OrtSequenceTypeInfo(const OrtSequenceTypeInfo& other) = delete; + OrtSequenceTypeInfo& operator=(const OrtSequenceTypeInfo& other) = delete; +}; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 080a3518048cb..42e03e802caf1 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -10,6 +10,11 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/onnxruntime_map_type_info.h" +#include "core/framework/onnxruntime_sequence_type_info.h" using onnxruntime::BFloat16; using onnxruntime::DataTypeImpl; @@ -20,11 +25,27 @@ using onnxruntime::TensorShape; namespace on = ONNX_NAMESPACE; +OrtTypeInfo::OrtTypeInfo(ONNXType type1) noexcept : type(type1) { +} + OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) { } +OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtMapTypeInfo* map_type_info1) noexcept : type(type1), map_type_info(map_type_info1) { +} + +OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtSequenceTypeInfo* sequence_type_info1) noexcept : type(type1), sequence_type_info(sequence_type_info1) { +} + OrtTypeInfo::~OrtTypeInfo() { OrtApis::ReleaseTensorTypeAndShapeInfo(data); + + if (map_type_info) { + 
OrtApis::ReleaseMapTypeInfo(map_type_info); + } + if (sequence_type_info) { + OrtApis::ReleaseSequenceTypeInfo(sequence_type_info); + } } ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input, ONNXType* out) { @@ -37,6 +58,28 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtType return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, const OrtTypeInfo* type_info, const OrtMapTypeInfo** out) { + API_IMPL_BEGIN + *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, const OrtTypeInfo* type_info, const OrtSequenceTypeInfo** out) { + API_IMPL_BEGIN + *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, const OrtTypeInfo* type_info, const char** const out, size_t* len) { + API_IMPL_BEGIN + *out = type_info->denotation.c_str(); + *len = type_info->denotation.size(); + return nullptr; + API_IMPL_END +} + ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { delete ptr; } @@ -49,7 +92,7 @@ OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vectorIsTensorSequenceType()) { - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr); + *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE); return nullptr; } @@ -92,16 +135,14 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { // Place Opaque first as tensors will be mostly handled above and maps and sequences are not common switch (type_proto->value_case()) { case on::TypeProto::kOpaqueType: { - *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr); + *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE); return nullptr; } case on::TypeProto::kMapType: { - *out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr); - return nullptr; + return 
OrtTypeInfo::FromTypeProto(type_proto, out); } case on::TypeProto::kSequenceType: { - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr); - return nullptr; + return OrtTypeInfo::FromTypeProto(type_proto, out); } // Real Tensor support case on::TypeProto::kTensorType: @@ -204,19 +245,39 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info); } if (st != nullptr) return st; - *out = new OrtTypeInfo(ten_type, info); + auto type_info = new OrtTypeInfo(ten_type, info); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::kSequenceType: { - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr); + OrtSequenceTypeInfo* sequence_type_info = nullptr; + + if (auto status = OrtSequenceTypeInfo::FromTypeProto(input, &sequence_type_info)) { + return status; + } + + auto type_info = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, sequence_type_info); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::kMapType: { - *out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr); + OrtMapTypeInfo* map_type_info = nullptr; + + if (auto status = OrtMapTypeInfo::FromTypeProto(input, &map_type_info)) { + return status; + } + + auto type_info = new OrtTypeInfo(ONNX_TYPE_MAP, map_type_info); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::kOpaqueType: { - *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr); + auto type_info = new OrtTypeInfo(ONNX_TYPE_OPAQUE); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::VALUE_NOT_SET: @@ -227,3 +288,48 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or } return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); } + +OrtStatus* OrtTypeInfo::Clone(OrtTypeInfo** out) { + switch (type) { + case 
ONNX_TYPE_TENSOR: + case ONNX_TYPE_SPARSETENSOR: + { + OrtTensorTypeAndShapeInfo* clone; + if (auto status = data->Clone(&clone)) { + return status; + } + *out = new OrtTypeInfo(type, clone); + (*out)->denotation = denotation; + return nullptr; + } + case ONNX_TYPE_SEQUENCE: + { + OrtSequenceTypeInfo* clone; + if (auto status = sequence_type_info->Clone(&clone)) { + return status; + } + *out = new OrtTypeInfo(type, clone); + (*out)->denotation = denotation; + return nullptr; + } + case ONNX_TYPE_MAP: { + OrtMapTypeInfo* clone; + if (auto status = map_type_info->Clone(&clone)) { + return status; + } + *out = new OrtTypeInfo(type, clone); + (*out)->denotation = denotation; + return nullptr; + } + case ONNX_TYPE_OPAQUE: + { + *out = new OrtTypeInfo(type); + (*out)->denotation = denotation; + return nullptr; + } + default: + // Not implemented + break; + } + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); +} \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index d615840dcb501..3c256aa73d17d 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { @@ -14,6 +15,10 @@ namespace ONNX_NAMESPACE { class TypeProto; } +// These types are only present in the winml adapter c api, so they are forward declared. 
+struct OrtMapTypeInfo; +struct OrtSequenceTypeInfo; + /** * the equivalent of ONNX_NAMESPACE::TypeProto * This class is mainly for the C API @@ -21,19 +26,26 @@ class TypeProto; struct OrtTypeInfo { public: ONNXType type = ONNX_TYPE_UNKNOWN; + std::string denotation; ~OrtTypeInfo(); //owned by this OrtTensorTypeAndShapeInfo* data = nullptr; + OrtMapTypeInfo* map_type_info = nullptr; + OrtSequenceTypeInfo* sequence_type_info = nullptr; OrtTypeInfo(const OrtTypeInfo& other) = delete; OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete; + OrtStatus* Clone(OrtTypeInfo** out); + static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out); static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out); - static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); private: + OrtTypeInfo(ONNXType type) noexcept; OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept; + OrtTypeInfo(ONNXType type, OrtMapTypeInfo* map_type_info) noexcept; + OrtTypeInfo(ONNXType type, OrtSequenceTypeInfo* sequence_type_info) noexcept; }; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 088043a159962..64bb11dbcbcfa 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -192,6 +192,11 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const st return GetTensorShapeAndTypeHelper(type, shape, dim_params, out); } +OrtStatus* OrtTensorTypeAndShapeInfo::Clone(OrtTensorTypeAndShapeInfo** out) +{ + return GetTensorShapeAndTypeHelper(type, shape, &dim_params, out); +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN onnxruntime::MLDataType type = v->Type(); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index 
28431a9d614cf..f781160cc6505 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -13,4 +13,6 @@ struct OrtTensorTypeAndShapeInfo { OrtTensorTypeAndShapeInfo() = default; OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete; + + OrtStatus* Clone(OrtTensorTypeAndShapeInfo** out); }; diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 00dbab1536e86..079c444197289 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -26,14 +26,22 @@ struct DMLProviderFactory : IExecutionProviderFactory { ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; + void SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode); private: ComPtr dml_device_{}; ComPtr cmd_queue_{}; + AllocatorRoundingMode rounding_mode_ = AllocatorRoundingMode::Enabled; }; std::unique_ptr DMLProviderFactory::CreateProvider() { - return Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get()); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get()); + Dml::SetDefaultRoundingMode(provider.get(), rounding_mode_); + return provider; +} + +void DMLProviderFactory::SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode) { + rounding_mode_ = rounding_mode; } std::shared_ptr CreateExecutionProviderFactory_DML(IDMLDevice* dml_device, @@ -57,8 +65,12 @@ std::shared_ptr CreateExecutionProviderFactory_DML(ID return std::make_shared(dml_device, cmd_queue); } -bool IsSoftwareAdapter(IDXGIAdapter1* adapter) -{ +void DmlConfigureProviderFactoryDefaultRoundingMode(IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode) { + auto dml_prvider_factory = static_cast(factory); + 
dml_prvider_factory->SetDefaultRoundingMode(rounding_mode); +} + +bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { DXGI_ADAPTER_DESC1 desc; adapter->GetDesc1(&desc); @@ -96,7 +108,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(in // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled #if _DEBUG ComPtr debug_device; - (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure + (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr); if (is_d3d12_debug_layer_enabled) { @@ -110,7 +122,6 @@ std::shared_ptr CreateExecutionProviderFactory_DML(in DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dml_device))); - return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b9d4714e000e4..482bea568a701 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -12,13 +12,13 @@ #include #include "core/common/logging/logging.h" -#include "core/common/logging/sinks/clog_sink.h" #include "core/common/status.h" #include "core/graph/graph.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" #include "core/framework/ml_value.h" #include "core/session/environment.h" +#include "core/session/onnxruntime_env.h" #include "core/framework/callback.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/onnxruntime_typeinfo.h" @@ -49,112 +49,6 @@ using namespace onnxruntime; if (_status) return _status; \ } while (0) -class LoggingWrapper : public ISink { - public: - LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param) - : logging_function_(logging_function), logger_param_(logger_param) { - } - - void SendImpl(const Timestamp& /*timestamp*/ /*timestamp*/, const std::string& logger_id, - const Capture& 
message) override { - std::string s = message.Location().ToString(); - logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), - logger_id.c_str(), s.c_str(), message.Message().c_str()); - } - - private: - OrtLoggingFunction logging_function_; - void* logger_param_; -}; - -struct OrtEnv { - public: - struct LoggingManagerConstructionInfo { - LoggingManagerConstructionInfo(OrtLoggingFunction logging_function1, - void* logger_param1, - OrtLoggingLevel default_warning_level1, - const char* logid1) - : logging_function(logging_function1), - logger_param(logger_param1), - default_warning_level(default_warning_level1), - logid(logid1) {} - OrtLoggingFunction logging_function{}; - void* logger_param{}; - OrtLoggingLevel default_warning_level; - const char* logid{}; - }; - - static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, Status& status) { - std::lock_guard lock(m_); - if (!p_instance_) { - std::unique_ptr env; - status = Environment::Create(env); - if (!status.IsOK()) { - return nullptr; - } - - std::unique_ptr lmgr; - std::string name = lm_info.logid; - if (lm_info.logging_function) { - std::unique_ptr logger = onnxruntime::make_unique(lm_info.logging_function, - lm_info.logger_param); - lmgr.reset(new LoggingManager(std::move(logger), - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } else { - lmgr.reset(new LoggingManager(std::unique_ptr{new CLogSink{}}, - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } - - p_instance_ = new OrtEnv(std::move(env), std::move(lmgr)); - } - ++ref_count_; - return p_instance_; - } - - static void Release(OrtEnv* env_ptr) { - if (!env_ptr) { - return; - } - std::lock_guard lock(m_); - ORT_ENFORCE(env_ptr == p_instance_); // sanity check - --ref_count_; - if (ref_count_ == 0) { - delete p_instance_; - p_instance_ = nullptr; - } - } - - LoggingManager* 
GetLoggingManager() const { - return logging_manager_.get(); - } - - private: - static OrtEnv* p_instance_; - static OrtMutex m_; - static int ref_count_; - - std::unique_ptr value_; - std::unique_ptr logging_manager_; - - OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager) - : value_(std::move(value1)), logging_manager_(std::move(logging_manager)) { - } - - ~OrtEnv() = default; - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); -}; - -OrtEnv* OrtEnv::p_instance_ = nullptr; -int OrtEnv::ref_count_ = 0; -OrtMutex OrtEnv::m_; - #define TENSOR_READ_API_BEGIN \ API_IMPL_BEGIN \ auto v = reinterpret_cast(value); \ @@ -1451,6 +1345,10 @@ static constexpr OrtApi ort_api_1 = { &OrtApis::ReleaseCustomOpDomain, }; +const OrtApi* GetVersion1Api() { + return &ort_api_1; +} + ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version > 1) return nullptr; @@ -1472,4 +1370,4 @@ ORT_API(void, OrtApis::ReleaseEnv, _Frees_ptr_opt_ OrtEnv* value) { DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions) -DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) +DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) \ No newline at end of file diff --git a/onnxruntime/core/session/onnxruntime_env.cc b/onnxruntime/core/session/onnxruntime_env.cc new file mode 100644 index 0000000000000..c37a2543a8eed --- /dev/null +++ b/onnxruntime/core/session/onnxruntime_env.cc @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +//this file contains implementations of the C API + +#include + +#include "onnxruntime_env.h" +#include "core/session/ort_apis.h" +#include "core/session/environment.h" +#include "core/common/logging/sinks/clog_sink.h" +#include "core/common/logging/logging.h" +#include "core/session/environment.h" + +using namespace onnxruntime; +using namespace onnxruntime::logging; + +OrtEnv* OrtEnv::p_instance_ = nullptr; +int OrtEnv::ref_count_ = 0; +onnxruntime::OrtMutex OrtEnv::m_; + +LoggingWrapper::LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param) + : logging_function_(logging_function), logger_param_(logger_param) { +} + +void LoggingWrapper::SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/ /*timestamp*/, const std::string& logger_id, + const onnxruntime::logging::Capture& message) { + std::string s = message.Location().ToString(); + logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), + logger_id.c_str(), s.c_str(), message.Message().c_str()); +} + +OrtEnv::OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager) + : value_(std::move(value1)), logging_manager_(std::move(logging_manager)) { +} + +OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status) { + std::lock_guard lock(m_); + if (!p_instance_) { + std::unique_ptr env; + status = onnxruntime::Environment::Create(env); + if (!status.IsOK()) { + return nullptr; + } + + std::unique_ptr lmgr; + std::string name = lm_info.logid; + if (lm_info.logging_function) { + std::unique_ptr logger = onnxruntime::make_unique(lm_info.logging_function, + lm_info.logger_param); + lmgr.reset(new LoggingManager(std::move(logger), + static_cast(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &name)); + } else { + lmgr.reset(new LoggingManager(std::unique_ptr{new CLogSink{}}, + static_cast(lm_info.default_warning_level), + false, + 
LoggingManager::InstanceType::Default, + &name)); + } + + p_instance_ = new OrtEnv(std::move(env), std::move(lmgr)); + } + ++ref_count_; + return p_instance_; +} + +void OrtEnv::Release(OrtEnv* env_ptr) { + if (!env_ptr) { + return; + } + std::lock_guard lock(m_); + ORT_ENFORCE(env_ptr == p_instance_); // sanity check + --ref_count_; + if (ref_count_ == 0) { + delete p_instance_; + p_instance_ = nullptr; + } +} + +LoggingManager* OrtEnv::GetLoggingManager() const { + return logging_manager_.get(); +} + +void OrtEnv::SetLoggingManager(std::unique_ptr logging_manager) { + std::lock_guard lock(m_); + logging_manager_ = std::move(logging_manager); +} \ No newline at end of file diff --git a/onnxruntime/core/session/onnxruntime_env.h b/onnxruntime/core/session/onnxruntime_env.h new file mode 100644 index 0000000000000..c93d2937c7a7b --- /dev/null +++ b/onnxruntime/core/session/onnxruntime_env.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include +#include "core/session/onnxruntime_c_api.h" +#include "core/common/logging/isink.h" +#include "core/platform/ort_mutex.h" +#include "core/common/status.h" + +namespace onnxruntime { +class Environment; +} + +class LoggingWrapper : public onnxruntime::logging::ISink { + public: + LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param); + + void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/ /*timestamp*/, const std::string& logger_id, + const onnxruntime::logging::Capture& message) override; + + private: + OrtLoggingFunction logging_function_; + void* logger_param_; +}; + +struct OrtEnv { + public: + struct LoggingManagerConstructionInfo { + LoggingManagerConstructionInfo(OrtLoggingFunction logging_function1, + void* logger_param1, + OrtLoggingLevel default_warning_level1, + const char* logid1) + : logging_function(logging_function1), + logger_param(logger_param1), + default_warning_level(default_warning_level1), + logid(logid1) {} + OrtLoggingFunction logging_function{}; + void* logger_param{}; + OrtLoggingLevel default_warning_level; + const char* logid{}; + }; + + static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status); + + static void Release(OrtEnv* env_ptr); + + onnxruntime::logging::LoggingManager* GetLoggingManager() const; + + void SetLoggingManager(std::unique_ptr logging_manager); + + private: + static OrtEnv* p_instance_; + static onnxruntime::OrtMutex m_; + static int ref_count_; + + std::unique_ptr value_; + std::unique_ptr logging_manager_; + + OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager); + ~OrtEnv() = default; + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); +}; \ No newline at end of file diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index cdc1ea7b6900f..4e3bf2274aaf4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -16,6 +16,8 
@@ ORT_API(void, ReleaseTypeInfo, OrtTypeInfo*); ORT_API(void, ReleaseTensorTypeAndShapeInfo, OrtTensorTypeAndShapeInfo*); ORT_API(void, ReleaseSessionOptions, OrtSessionOptions*); ORT_API(void, ReleaseCustomOpDomain, OrtCustomOpDomain*); +ORT_API(void, ReleaseMapTypeInfo, OrtMapTypeInfo*); +ORT_API(void, ReleaseSequenceTypeInfo, OrtSequenceTypeInfo*); ORT_API_STATUS_IMPL(CreateStatus, OrtErrorCode code, _In_ const char* msg); OrtErrorCode ORT_API_CALL GetErrorCode(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; @@ -144,4 +146,16 @@ ORT_API_STATUS_IMPL(KernelContext_GetOutputCount, _In_ const OrtKernelContext* c ORT_API_STATUS_IMPL(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out); +// OrtTypeInfo methods +ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len); +ORT_API_STATUS_IMPL(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const OrtMapTypeInfo** out); +ORT_API_STATUS_IMPL(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const OrtSequenceTypeInfo** out); + +// OrtMapTypeInfo Accessors +ORT_API_STATUS_IMPL(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); +ORT_API_STATUS_IMPL(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); + +// OrtSequenceTypeInfo Accessors +ORT_API_STATUS_IMPL(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info); + } // namespace OrtApis diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index de7185e091f93..6a10e2f0f6707 100644 --- 
a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -1,14 +1,4 @@ jobs: -- template: templates/win-ci.yml - parameters: - AgentPool : 'Win-CPU' - DoDebugBuild: 'true' - DoCompliance: 'false' - BuildCommand: '$(Build.SourcesDirectory)\tools\ci_build\build.py --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_path $(Build.BinariesDirectory)\cmake\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake\bin\ctest.exe --use_tvm --use_automl --enable_pybind --use_mkldnn --use_openmp --use_winml --build_shared_lib --build_csharp --enable_onnx_tests' - JobName: 'Windows_CI_Dev' - DoNugetPack: 'false' - NuPackScript : '' - DoTestCoverage: 'false' - job: 'build' pool: 'Win-CPU-2019' strategy: @@ -66,7 +56,7 @@ jobs: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --build_wheel --use_featurizers --use_dnnl --use_openmp --build_shared_lib --enable_onnx_tests --build_java' + arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --build_wheel --use_featurizers --use_dnnl --use_winml --use_openmp --build_shared_lib --enable_onnx_tests --build_java' workingDirectory: '$(Build.BinariesDirectory)' - task: VSBuild@1 diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 0bf9c923e78b9..3ef5bef8c751f 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -42,7 +42,8 @@ jobs: displayName: 'Generate cmake config' inputs: scriptPath: 
'$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --use_featurizers --use_dnnl --build_shared_lib --enable_onnx_tests --use_dml --use_cuda --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0' + arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --use_featurizers + nnl --build_shared_lib --enable_onnx_tests --use_dml --use_winml --use_cuda --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0' workingDirectory: '$(Build.BinariesDirectory)' - task: VSBuild@1 diff --git a/winml/adapter/CpuOrtSessionBuilder.cpp b/winml/adapter/CpuOrtSessionBuilder.cpp deleted file mode 100644 index 72d09ff022941..0000000000000 --- a/winml/adapter/CpuOrtSessionBuilder.cpp +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -#include "pch.h" - -// Needed to work around the fact that OnnxRuntime defines ERROR -#ifdef ERROR -#undef ERROR -#endif -#include "core/session/inference_session.h" -// Restore ERROR define -#define ERROR 0 - -#include "CpuOrtSessionBuilder.h" -#include "WinMLAdapter.h" -#include "WinMLAdapterErrors.h" - -// winml includes -#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" - -// ort includes -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/optimizer/conv_activation_fusion.h" -#include "core/optimizer/gemm_activation_fusion.h" -#include "core/session/abi_session_options_impl.h" - -using namespace Windows::AI::MachineLearning; - -namespace Windows::AI::MachineLearning::Adapter { - -CpuOrtSessionBuilder::CpuOrtSessionBuilder() { - -} - -HRESULT -CpuOrtSessionBuilder::CreateSessionOptions( - OrtSessionOptions** options) try { - RETURN_HR_IF_NULL(E_POINTER, options); - - Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options)); - Ort::SessionOptions session_options(*options); - - // set the graph optimization level to all (used to be called level 3) - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - - // Onnxruntime will use half the number of concurrent threads supported on the system - // by default. This causes MLAS to not exercise every logical core. - // We force the thread pool size to be maxxed out to ensure that WinML always - // runs the fastest. 
- session_options.SetIntraOpNumThreads(std::thread::hardware_concurrency()); - - // call release() so the underlying OrtSessionOptions object isn't freed - session_options.release(); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT -CpuOrtSessionBuilder::CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) try { - RETURN_HR_IF_NULL(E_POINTER, p_session); - RETURN_HR_IF_NULL(E_POINTER, pp_provider); - RETURN_HR_IF(E_POINTER, *pp_provider != nullptr); - - // Create the inference session - auto session = std::make_unique(options->value); - - // Create the cpu execution provider - onnxruntime::CPUExecutionProviderInfo xpInfo; -#ifndef _WIN64 - xpInfo.create_arena = false; -#endif - auto cpu_provider = std::make_unique(xpInfo); - - // Cache the provider's raw pointer - *pp_provider = cpu_provider.get(); - - // Register the cpu xp - ORT_THROW_IF_ERROR(session->RegisterExecutionProvider(std::move(cpu_provider))); - - // assign the session to the out parameter - auto sessionptr = wil::MakeOrThrow(session.release()); - RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(winmla::IInferenceSession), (void**)p_session)); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT -CpuOrtSessionBuilder::Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* /*p_provider*/ -) try { - ORT_THROW_IF_ERROR(p_session->get()->Initialize()); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -} // Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/adapter/CpuOrtSessionBuilder.h b/winml/adapter/CpuOrtSessionBuilder.h deleted file mode 100644 index 700129275f490..0000000000000 --- a/winml/adapter/CpuOrtSessionBuilder.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -#pragma once - -#include "WinMLAdapter.h" - -namespace Windows::AI::MachineLearning::Adapter { - -class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - winmla::IOrtSessionBuilder> { - - public: - CpuOrtSessionBuilder(); - - HRESULT STDMETHODCALLTYPE CreateSessionOptions( - OrtSessionOptions** options) override; - - HRESULT STDMETHODCALLTYPE CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) override; - - HRESULT STDMETHODCALLTYPE Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* p_provider) override; -}; - -} // namespace Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/adapter/CustomRegistryHelper.h b/winml/adapter/CustomRegistryHelper.h deleted file mode 100644 index de2987e676447..0000000000000 --- a/winml/adapter/CustomRegistryHelper.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#ifdef USE_DML -#include "core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h" - -namespace Windows::AI::MachineLearning::Adapter { - -inline std::list> -GetLotusCustomRegistries( - IMLOperatorRegistry* registry) { - if (registry != nullptr) { - // Down-cast to the concrete type. - // The only supported input is the AbiCustomRegistry type. - // Other implementations of IMLOperatorRegistry are forbidden. 
- auto abi_custom_registry = - static_cast(registry); - - // Get the ORT registry - return abi_custom_registry->GetRegistries(); - } - - return {}; -} - -} // namespace Windows::AI::MachineLearning::Adapter - -#endif USE_DML diff --git a/winml/adapter/DmlOrtSessionBuilder.cpp b/winml/adapter/DmlOrtSessionBuilder.cpp deleted file mode 100644 index 29da1a6332642..0000000000000 --- a/winml/adapter/DmlOrtSessionBuilder.cpp +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "pch.h" - -#ifdef USE_DML - -// Needed to work around the fact that OnnxRuntime defines ERROR -#ifdef ERROR -#undef ERROR -#endif -#include "core/session/inference_session.h" -// Restore ERROR define -#define ERROR 0 - -#include "DmlOrtSessionBuilder.h" -#include "WinMLAdapterErrors.h" - -// winml includes -#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" -#include "CustomRegistryHelper.h" -#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" -#include "LearningModelDevice.h" -#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" - -// ort includes -#include "core/framework/op_kernel.h" -#include "core/framework/op_node_proto_helper.h" -#include "core/framework/customRegistry.h" -#include "core/framework/data_transfer.h" -#include "core/session/abi_session_options_impl.h" - -using namespace Windows::AI::MachineLearning; - -namespace Windows::AI::MachineLearning::Adapter { - -DmlOrtSessionBuilder::DmlOrtSessionBuilder( - ID3D12Device* device, - ID3D12CommandQueue* queue) { - device_.copy_from(device); - queue_.copy_from(queue); -} - -HRESULT -DmlOrtSessionBuilder::CreateSessionOptions( - OrtSessionOptions** options) try { - RETURN_HR_IF_NULL(E_POINTER, options); - - Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options)); - Ort::SessionOptions session_options(*options); - - // set the graph optimization level to all (used to be called level 3) - 
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - - // Disable the mem pattern session option for DML. It will cause problems with how memory is allocated. - session_options.DisableMemPattern(); - - // call release() so the underlying OrtSessionOptions object isn't freed - session_options.release(); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -static HRESULT -RegisterCustomRegistry( - onnxruntime::InferenceSession* p_session, - IMLOperatorRegistry* registry) { - if (registry != nullptr) { - RETURN_HR_IF_NULL(E_POINTER, p_session); - - auto custom_registries = GetLotusCustomRegistries(registry); - - // Register - for (auto& custom_registry : custom_registries) { - ORT_THROW_IF_ERROR(p_session->RegisterCustomRegistry(custom_registry)); - } - } - - return S_OK; -} - -Microsoft::WRL::ComPtr CreateDmlDevice(ID3D12Device* d3d12Device) { - // Dynamically load DML to avoid WinML taking a static dependency on DirectML.dll - wil::unique_hmodule dmlDll(LoadLibraryW(L"DirectML.dll")); - THROW_LAST_ERROR_IF(!dmlDll); - - auto dmlCreateDevice1Fn = reinterpret_cast( - GetProcAddress(dmlDll.get(), "DMLCreateDevice1")); - THROW_LAST_ERROR_IF(!dmlCreateDevice1Fn); - - DML_CREATE_DEVICE_FLAGS dmlFlags = DML_CREATE_DEVICE_FLAG_NONE; - - // Enable the DML debug layer in DEBUG builds, if the D3D12 debug layer is also enabled -#if _DEBUG - Microsoft::WRL::ComPtr d3d12DebugDevice; - if (SUCCEEDED(d3d12Device->QueryInterface(IID_PPV_ARGS(&d3d12DebugDevice)))) { - d3d12DebugDevice = nullptr; - dmlFlags |= DML_CREATE_DEVICE_FLAG_DEBUG; - } -#endif - - Microsoft::WRL::ComPtr dmlDevice; - THROW_IF_FAILED(dmlCreateDevice1Fn(d3d12Device, dmlFlags, DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dmlDevice))); - - // Keep DirectML.dll loaded by leaking the handle. This is equivalent behavior to if we delay-loaded the DLL. 
- dmlDll.release(); - - return dmlDevice; -} - -HRESULT DmlOrtSessionBuilder::CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) try { - RETURN_HR_IF_NULL(E_POINTER, p_session); - RETURN_HR_IF_NULL(E_POINTER, pp_provider); - RETURN_HR_IF(E_POINTER, *pp_provider != nullptr); - - auto p_d3d_device = device_.get(); - auto p_queue = queue_.get(); - - Microsoft::WRL::ComPtr dmlDevice = CreateDmlDevice(p_d3d_device); - - std::unique_ptr gpu_provider = Dml::CreateExecutionProvider(dmlDevice.Get(), p_queue); - auto session = std::make_unique(options->value); - - const onnxruntime::Env& env = onnxruntime::Env::Default(); - LUID temp_LUID = p_d3d_device->GetAdapterLuid(); - env.GetTelemetryProvider().LogExecutionProviderEvent(&temp_LUID); - // Cache the provider's raw pointer - *pp_provider = gpu_provider.get(); - - ORT_THROW_IF_ERROR(session->RegisterExecutionProvider(std::move(gpu_provider))); - - // assign the session to the out parameter - auto sessionptr = wil::MakeOrThrow(session.release()); - RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(winmla::IInferenceSession), (void**)p_session)); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT DmlOrtSessionBuilder::Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* p_provider) try { - RETURN_HR_IF_NULL(E_INVALIDARG, p_session); - RETURN_HR_IF_NULL(E_INVALIDARG, p_provider); - - // OnnxRuntime uses the default rounding mode when calling the session's allocator. - // During initialization, OnnxRuntime allocates weights, which are permanent across session - // lifetime and can be large, so shouldn't be rounded. 
- Dml::SetDefaultRoundingMode(p_provider, AllocatorRoundingMode::Disabled); - - ORT_THROW_IF_ERROR(p_session->get()->Initialize()); - - Dml::SetDefaultRoundingMode(p_provider, AllocatorRoundingMode::Enabled); - - // Flush the D3D12 work from the DML execution provider - Dml::FlushContext(p_provider); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -} // namespace Windows::AI::MachineLearning::Adapter - -#endif USE_DML \ No newline at end of file diff --git a/winml/adapter/DmlOrtSessionBuilder.h b/winml/adapter/DmlOrtSessionBuilder.h deleted file mode 100644 index a02d1c21f8800..0000000000000 --- a/winml/adapter/DmlOrtSessionBuilder.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "WinMLAdapter.h" - -namespace Windows::AI::MachineLearning::Adapter { - -class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - winmla::IOrtSessionBuilder> { - - public: - DmlOrtSessionBuilder(ID3D12Device* device, ID3D12CommandQueue* queue); - - HRESULT STDMETHODCALLTYPE CreateSessionOptions( - OrtSessionOptions** options) override; - - HRESULT STDMETHODCALLTYPE CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) override; - - HRESULT STDMETHODCALLTYPE Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* p_provider) override; - - private: - winrt::com_ptr device_; - winrt::com_ptr queue_; -}; - -} // namespace Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/adapter/FeatureDescriptorFactory.h b/winml/adapter/FeatureDescriptorFactory.h deleted file mode 100644 index 497f92d9cbc8b..0000000000000 --- a/winml/adapter/FeatureDescriptorFactory.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -#pragma once -#include "pch.h" - -namespace Windows::AI::MachineLearning { - -struct FeatureDescriptorFactory { - FeatureDescriptorFactory( - const std::unordered_map& model_metadata); - - wfc::IVector - CreateDescriptorsFromValueInfoProtos( - const std::vector& value_info_protos); - - private: - const std::unordered_map& metadata_; -}; - -} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/adapter/LotusEnvironment.cpp b/winml/adapter/LotusEnvironment.cpp deleted file mode 100644 index 30e3af20c46fc..0000000000000 --- a/winml/adapter/LotusEnvironment.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "pch.h" -#include "LotusEnvironment.h" -#include "core/platform/windows/TraceLoggingConfig.h" -#include - -bool Windows::AI::MachineLearning::CWinMLLogSink::debug_output_ = false; -void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl( - const onnxruntime::logging::Timestamp& timestamp, - const std::string& logger_id, - const onnxruntime::logging::Capture& message) { - // ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry. 
- switch (message.Severity()) { - case (onnxruntime::logging::Severity::kFATAL): //Telemetry - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str()), - TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); - break; - case (onnxruntime::logging::Severity::kERROR): //Telemetry - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_ERROR), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str()), - TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); - break; - case (onnxruntime::logging::Severity::kWARNING): - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_WARNING), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str())); - break; - case (onnxruntime::logging::Severity::kINFO): - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - 
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str())); - break; - case (onnxruntime::logging::Severity::kVERBOSE): - __fallthrough; //Default is Verbose too. - default: - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str())); - } - if (debug_output_) { - OutputDebugStringA(std::string(message.Message() + "\r\n").c_str()); - } -} - -void Windows::AI::MachineLearning::CWinMLLogSink::SendProfileEvent(onnxruntime::profiling::EventRecord& eventRecord) const { - if (eventRecord.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) { - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "OnnxRuntimeProfiling", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(onnxruntime::profiling::event_categor_names_[eventRecord.cat], "Category"), - TraceLoggingInt64(eventRecord.dur, "Duration (us)"), - TraceLoggingInt64(eventRecord.ts, "Time Stamp (us)"), - TraceLoggingString(eventRecord.name.c_str(), "Event Name"), - TraceLoggingInt32(eventRecord.pid, "Process ID"), - TraceLoggingInt32(eventRecord.tid, "Thread ID"), - TraceLoggingString(eventRecord.args["op_name"].c_str(), "Operator Name"), - 
TraceLoggingString(eventRecord.args["provider"].c_str(), "Execution Provider")); - } else { - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "OnnxRuntimeProfiling", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(onnxruntime::profiling::event_categor_names_[eventRecord.cat], "Category"), - TraceLoggingInt64(eventRecord.dur, "Duration (us)"), - TraceLoggingInt64(eventRecord.ts, "Time Stamp (us)"), - TraceLoggingString(eventRecord.name.c_str(), "Event Name"), - TraceLoggingInt32(eventRecord.pid, "Process ID"), - TraceLoggingInt32(eventRecord.tid, "Thread ID")); - } -} diff --git a/winml/adapter/LotusEnvironment.h b/winml/adapter/LotusEnvironment.h deleted file mode 100644 index 37bb8ad7ed584..0000000000000 --- a/winml/adapter/LotusEnvironment.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once -#include "core/common/logging/isink.h" -#include -#include -#include "WinMLAdapter.h" - -#pragma warning(push) -#pragma warning(disable : 4505) - -namespace Windows { -namespace AI { -namespace MachineLearning { -class CWinMLLogSink : public onnxruntime::logging::ISink { - public: - CWinMLLogSink() { - } - static void EnableDebugOutput() { - debug_output_ = true; - OutputDebugStringW(L"Windows.AI.MachineLearning: Debug Output Enabled \r\n"); - } - void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const; - void SendImpl(const onnxruntime::logging::Timestamp& timestamp, const std::string& logger_id, const onnxruntime::logging::Capture& message); - - private: - static bool debug_output_; -}; -// TODO: a bug in ORT requires a logging manager. 
This function registers a static singleton logger as "default" -inline onnxruntime::logging::LoggingManager& DefaultLoggingManager() { - // create a CLog based default logging manager - static std::string default_logger_id{"Default"}; - static onnxruntime::logging::LoggingManager default_logging_manager{ - std::unique_ptr{new CWinMLLogSink()}, - onnxruntime::logging::Severity::kVERBOSE, - false, - onnxruntime::logging::LoggingManager::InstanceType::Default, - &default_logger_id, - MAXINT32}; - - return default_logging_manager; -} - -class LotusEnvironment { - public: - LotusEnvironment() { - const HRESULT etw_status = TraceLoggingRegister(winmla::winml_trace_logging_provider); - if (FAILED(etw_status)) { - throw std::runtime_error("WinML TraceLogging registration failed. Logging will be broken: " + std::to_string(etw_status)); - } - - // TODO: Do we need to call this or just define the method? - default_logging_manager_ = &DefaultLoggingManager(); - - if (!onnxruntime::Environment::Create(lotus_environment_).IsOK()) { - throw winrt::hresult_error(E_FAIL); - } - - auto allocatorMap = onnxruntime::DeviceAllocatorRegistry::Instance().AllRegistrations(); - if (allocatorMap.find("Cpu") == allocatorMap.end()) { - onnxruntime::DeviceAllocatorRegistry::Instance().RegisterDeviceAllocator( - "Cpu", - [](int) { return std::make_unique(); }, - std::numeric_limits::max()); - } - } - - ~LotusEnvironment() { - TraceLoggingUnregister(winmla::winml_trace_logging_provider); - } - - const onnxruntime::logging::Logger* GetDefaultLogger() { - return &default_logging_manager_->DefaultLogger(); - } - - private: - std::unique_ptr lotus_environment_; - onnxruntime::logging::LoggingManager* default_logging_manager_; -}; - -namespace ExecutionProviders { -__declspec(selectany) const char* CPUExecutionProvider = "CPUExecutionProvider"; -} - -} // namespace MachineLearning -} // namespace AI -} // namespace Windows - -#pragma warning(pop) \ No newline at end of file diff --git 
a/winml/adapter/WinMLAdapter.cpp b/winml/adapter/WinMLAdapter.cpp deleted file mode 100644 index eef56c60799a8..0000000000000 --- a/winml/adapter/WinMLAdapter.cpp +++ /dev/null @@ -1,759 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "pch.h" -#include "WinMLAdapter.h" -#include "WinMLAdapterErrors.h" -#include "CustomRegistryHelper.h" -#include "PheonixSingleton.h" -#include "LotusEnvironment.h" -#include "AbiCustomRegistryImpl.h" - -#ifdef USE_DML -#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" -#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" -#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" -#include "DmlOrtSessionBuilder.h" -#endif USE_DML - -#include "LearningModelDevice.h" -#include "TensorFeatureDescriptor.h" -#include "ImageFeatureDescriptor.h" -#include "api.image/inc/D3DDeviceCache.h" -#include "Common/inc/WinMLTelemetryHelper.h" - -#include "CpuOrtSessionBuilder.h" - -#include -#include - -#include "ZeroCopyInputStreamWrapper.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" - -#include "FeatureDescriptorFactory.h" -#include "core\framework\utils.h" -#include "core\framework\session_state.h" -#include "core/providers/winml/winml_provider_factory.h" - -using namespace winrt::Windows::AI::MachineLearning; - -namespace Windows::AI::MachineLearning::Adapter { - -// Define winml trace logging provider with WinML GUID -TRACELOGGING_DEFINE_PROVIDER( - winml_trace_logging_provider, - WINML_PROVIDER_DESC, - WINML_PROVIDER_GUID); - -// ORT intentionally requires callers derive from their session class to access -// the protected methods used below. 
-class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSession { - public: - onnxruntime::common::Status - Load(std::unique_ptr p_model_proto) { - return onnxruntime::InferenceSession::Load(std::move(p_model_proto)); - } - const onnxruntime::SessionState& GetSessionState() { - return *session_state_; - } -}; - -class ModelProto : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IModelProto> { - public: - ModelProto::ModelProto(onnx::ModelProto* model_proto) : model_proto_(model_proto) { - } - - onnx::ModelProto* STDMETHODCALLTYPE get() noexcept override { - return model_proto_.get(); - } - - onnx::ModelProto* STDMETHODCALLTYPE detach() noexcept override { - return model_proto_.release(); - } - - private: - std::unique_ptr model_proto_; -}; // class ModelProto - -class ModelInfo : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IModelInfo> { - private: - std::string author_; - std::string name_; - std::string domain_; - std::string description_; - int64_t version_; - std::unordered_map model_metadata_; - wfc::IVector input_features_; - wfc::IVector output_features_; - - public: - ModelInfo(const onnx::ModelProto* model_proto) { - Initialize(model_proto); - } - - const char* STDMETHODCALLTYPE author() noexcept override { - return author_.c_str(); - } - - const char* STDMETHODCALLTYPE name() noexcept override { - return name_.c_str(); - } - - const char* STDMETHODCALLTYPE domain() noexcept override { - return domain_.c_str(); - } - - const char* STDMETHODCALLTYPE description() noexcept override { - return description_.c_str(); - } - - int64_t STDMETHODCALLTYPE version() noexcept override { - return version_; - } - - HRESULT STDMETHODCALLTYPE GetModelMetadata( - ABI::Windows::Foundation::Collections::IMapView** metadata) override try { - *metadata = nullptr; - std::unordered_map map_copy; - for (auto& pair : model_metadata_) { - auto key = WinML::Strings::HStringFromUTF8(pair.first); - auto 
map_value = WinML::Strings::HStringFromUTF8(pair.second); - map_copy.emplace(std::move(key), std::move(map_value)); - } - auto out = winrt::single_threaded_map( - std::move(map_copy)); - - winrt::copy_to_abi(out.GetView(), *(void**)metadata); - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetInputFeatures( - ABI::Windows::Foundation::Collections::IVectorView** features) override try { - *features = nullptr; - winrt::copy_to_abi(input_features_.GetView(), *(void**)features); - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetOutputFeatures( - ABI::Windows::Foundation::Collections::IVectorView** features) override try { - *features = nullptr; - winrt::copy_to_abi(output_features_.GetView(), *(void**)features); - return S_OK; - } - WINMLA_CATCH_ALL_COM - - static std::vector - GetAllNodeOutputs(const onnx::ModelProto& model_proto) { - std::vector nodes_outputs; - auto& graph = model_proto.graph(); - auto& nodes = graph.node(); - for (auto& node : nodes) { - for (auto& node_output : node.output()) { - nodes_outputs.push_back(node_output.c_str()); - } - } - return nodes_outputs; - } - - static std::vector - GetInitializers(const onnx::ModelProto& model_proto) { - std::vector initializers; - auto& graph = model_proto.graph(); - auto& graph_initializers = graph.initializer(); - for (auto& initializer : graph_initializers) { - initializers.push_back(initializer.name().c_str()); - } - return initializers; - } - - static std::vector - GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { - auto initializers = GetInitializers(model_proto); - - std::vector inputs_without_initializers; - auto& graph = model_proto.graph(); - auto& inputs = graph.input(); - for (auto& input : inputs) { - if (input.has_name() && input.has_type()) { - auto found_it = std::find_if( - std::begin(initializers), - std::end(initializers), - [&](auto& initializer) { - return std::strcmp(initializer, input.name().c_str()) == 0; - }); - - auto 
is_initializer = found_it != std::end(initializers); - if (!is_initializer) { - inputs_without_initializers.push_back(&input); - } - } - } - return inputs_without_initializers; - } - - static std::vector GetOutputs(const onnx::ModelProto& model_proto) { - std::vector outputs_with_name; - auto& graph = model_proto.graph(); - auto& outputs = graph.output(); - for (auto& output : outputs) { - if (output.has_name() && output.has_type()) { - outputs_with_name.push_back(&output); - } - } - return outputs_with_name; - } - - private: - void Initialize(const onnx::ModelProto* model_proto) { - // metadata - for (auto& prop : model_proto->metadata_props()) { - model_metadata_[prop.key()] = prop.value(); - } - - WinML::FeatureDescriptorFactory builder(model_metadata_); - - // Create inputs - auto inputs = GetInputsWithoutInitializers(*model_proto); - input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); - - // Create outputs - auto outputs = GetOutputs(*model_proto); - output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); - - // author - auto has_producer_name = model_proto->has_producer_name(); - author_ = has_producer_name - ? model_proto->producer_name() - : ""; - - // domain - auto has_domain = model_proto->has_domain(); - domain_ = has_domain - ? model_proto->domain() - : ""; - - // name - auto has_graph = model_proto->has_graph(); - auto graph_has_name = model_proto->graph().has_name(); - auto is_name_available = has_graph && graph_has_name; - name_ = is_name_available - ? model_proto->graph().name() - : ""; - - // description - auto has_description = model_proto->has_doc_string(); - description_ = has_description - ? model_proto->doc_string() - : ""; - - // version - auto has_version = model_proto->has_model_version(); - version_ = has_version - ? 
model_proto->model_version() - : 0; - } -}; // class ModelInfo - -class WinMLAdapter : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IWinMLAdapter> { - private: - // TODO: Making this static is only temporary. A fix addressing the resulting the memory leaks is needed. - static std::shared_ptr lotus_environment_; - - public: - WinMLAdapter() { - if (lotus_environment_ == nullptr) { - lotus_environment_ = PheonixSingleton(); - } - } - // factory methods for creating an ort model from a path - HRESULT STDMETHODCALLTYPE CreateModelProto( - const char* path, - IModelProto** model_proto) override try { - int file_descriptor; - _set_errno(0); // clear errno - _sopen_s( - &file_descriptor, - path, - O_RDONLY | _O_SEQUENTIAL | _O_BINARY, - _SH_DENYWR, - _S_IREAD | _S_IWRITE); - - errno_t err = 0; - _get_errno(&err); - THROW_HR_IF_MSG( - __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND), - err == ENOENT, - "File not found: %s", - path); - - THROW_HR_IF_MSG( - E_FAIL, - 0 > file_descriptor, - "Failed"); //errno - - auto stream = google::protobuf::io::FileInputStream(file_descriptor); - stream.SetCloseOnDelete(true); - - auto model_proto_inner = new onnx::ModelProto(); - THROW_HR_IF_MSG( - E_INVALIDARG, - model_proto_inner->ParseFromZeroCopyStream(&stream) == false, - "The stream failed to parse."); - - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - WINMLA_CATCH_ALL_COM - - // factory methods for creating an ort model from a stream - HRESULT STDMETHODCALLTYPE CreateModelProto( - ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, - IModelProto** model_proto) override try { - ZeroCopyInputStreamWrapper wrapper(stream_reference); - - auto model_proto_inner = std::make_unique(); - THROW_HR_IF_MSG( - E_INVALIDARG, - model_proto_inner->ParseFromZeroCopyStream(&wrapper) == false, - "The stream failed to parse."); - - auto 
model_proto_outer = wil::MakeOrThrow(model_proto_inner.release()); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - WINMLA_CATCH_ALL_COM - - // factory methods for creating an ort model from a model_proto - HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto* model_proto_in, IModelProto** model_proto) override try { - auto model_proto_inner = new onnx::ModelProto(*model_proto_in->get()); - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto* model_proto, IModelInfo** model_info) override try { - auto model_info_outer = wil::MakeOrThrow(model_proto->get()); - return model_info_outer.CopyTo(__uuidof(IModelInfo), reinterpret_cast(model_info)); - } - WINMLA_CATCH_ALL_COM - - void STDMETHODCALLTYPE EnableDebugOutput() override try { - WinML::CWinMLLogSink::EnableDebugOutput(); - } - WINMLA_CATCH_ALL_DONOTHING - - static bool IsFeatureDescriptorFp16( - winml::ILearningModelFeatureDescriptor descriptor) { - if (auto imageFeatureDescriptor = descriptor.try_as()) { - return TensorKind::Float16 == imageFeatureDescriptor.TensorKind(); - } - - if (auto tensorFeatureDescriptor = descriptor.try_as()) { - return TensorKind::Float16 == tensorFeatureDescriptor.TensorKind(); - } - - return false; - } - - HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility( - winml::LearningModel const& model, - IModelProto* p_model_proto, - bool is_float16_supported) override try { - if (!is_float16_supported) { - auto& graph = p_model_proto->get()->graph(); - - // The model will not contain fp16 operations if: - // 1. The model has no fp16 inputs - // 2. The model has no fp16 initializers - // 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator - // 4. The model does not have any fp16 outputs - - // 1. 
Ensure that The model has no fp16 inputs - for (auto descriptor : model.InputFeatures()) { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - IsFeatureDescriptorFp16(descriptor), - "The model contains a 16-bit input (%ls), but the current device does not support 16-bit float.", - descriptor.Name().c_str()); - } - - // 2. Ensure that the model has no fp16 initializers - for (int i = 0; i < graph.node_size(); i++) { - auto node = graph.node(i); - if (node.op_type() == "Cast" && node.domain().empty()) { - for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { - auto attribute = node.attribute(attribIndex); - if (attribute.name() == "to") { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, - "The model contains a 16-bit float Cast Op (%s), but the current device does not support 16-bit float.", - node.name().c_str()); - } - } - } - } - - // 3. Ensure that the model does not create any fp16 intermediary - // tensors via the Cast (to float16) operator - for (int i = 0; i < graph.initializer_size(); i++) { - auto initializer = graph.initializer(i); - - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, - "The model contains a 16-bit float initializer (%s), but the current device does not support 16-bit float.", - initializer.name().c_str()); - } - - // 4. 
Ensure that the model does not have any fp16 outputs - for (auto descriptor : model.OutputFeatures()) { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - IsFeatureDescriptorFp16(descriptor), - "The model contains a 16-bit output (%ls), but the current device does not support 16-bit float.", - descriptor.Name().c_str()); - } - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override try { -#ifdef USE_DML - auto d3dResource = - Dml::GetD3D12ResourceFromAllocation( - provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), - allocation); - return d3dResource; -#else - return nullptr; -#endif USE_DML - } catch (...) { - return nullptr; - } - - static onnxruntime::MLDataType GetType(winml::TensorKind kind) { - switch (kind) { - case winml::TensorKind::Float: - return onnxruntime::DataTypeImpl::GetType(); - case winml::TensorKind::Float16: - return onnxruntime::DataTypeImpl::GetType(); - }; - return nullptr; - } - - // factory method for creating an ortsessionbuilder from a device - HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder( - ID3D12Device* device, - ID3D12CommandQueue* queue, - IOrtSessionBuilder** session_builder) override try { - if (device == nullptr) { - auto builder = wil::MakeOrThrow(); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); - } -#ifdef USE_DML - else { - auto builder = wil::MakeOrThrow(device, queue); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); - } -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override try { - *key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - auto type = ort_value->Type(); - if (type == onnxruntime::DataTypeImpl::GetType()) { - 
*key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override try { - *key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - auto type = ort_value->Type(); - if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE 
GetCustomRegistry(IMLOperatorRegistry** registry) override try { -#ifdef USE_DML - auto impl = wil::MakeOrThrow(); - *registry = impl.Detach(); - return S_OK; -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative* operator_provider_native, IMLOperatorRegistry** registry) override try { -#ifdef USE_DML - // Retrieve the "operator abi" registry. - winrt::com_ptr operator_registry; - THROW_IF_FAILED(operator_provider_native->GetRegistry(operator_registry.put())); - *registry = operator_registry.detach(); - return S_OK; -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override try { -#ifdef USE_DML - return Dml::CreateGPUAllocationFromD3DResource(pResource); -#else - return nullptr; -#endif USE_DML - } catch (...) { - return nullptr; - } - - void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override try { -#ifdef USE_DML - Dml::FreeGPUAllocation(ptr); -#endif USE_DML - } - WINMLA_CATCH_ALL_DONOTHING - - HRESULT STDMETHODCALLTYPE CopyTensor( - onnxruntime::IExecutionProvider* provider, - OrtValue* src, - OrtValue* dst) override try { -#ifdef USE_DML - ORT_THROW_IF_ERROR(Dml::CopyTensor(provider, *(src->GetMutable()), *(dst->GetMutable()))); - return S_OK; -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - // Override select shape inference functions which are incomplete in ONNX with versions that are complete, - // and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being - // deferred until first evaluation. It also prevents a situation where inference functions in externally - // registered schema are reachable only after upstream schema have been revised in a later OS release, - // which would be a compatibility risk. 
- HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override try { -#ifdef USE_DML - static std::once_flag schema_override_once_flag; - std::call_once(schema_override_once_flag, []() { - SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); - }); - return S_OK; -#else - return S_OK; // needs to return S_OK otherwise everything breaks because this gets called from the learningmodel constructor -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo( - onnxruntime::IExecutionProvider* provider, - OrtMemoryInfo** memory_info) override try { - auto allocator = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); - - const auto& info = allocator->Info(); - *memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); - if (*memory_info == nullptr) { - return E_OUTOFMEMORY; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue* ort_value, OrtMemoryInfo** memory_info) override try { - const auto& tensor = ort_value->Get(); - auto info = tensor.Location(); - *memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); - if (*memory_info == nullptr) { - return E_OUTOFMEMORY; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - struct AllocatorWrapper : public OrtAllocator { - public: - AllocatorWrapper(onnxruntime::AllocatorPtr impl) : impl_(impl) { - version = ORT_API_VERSION; - Alloc = AllocImpl; - Free = FreeImpl; - Info = InfoImpl; - } - - static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { - return static_cast(this_)->impl_->Alloc(size); - } - static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { - return static_cast(this_)->impl_->Free(p); - } - static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { - return &(static_cast(this_)->impl_->Info()); - } - - private: - onnxruntime::AllocatorPtr impl_; 
- }; - - HRESULT STDMETHODCALLTYPE GetProviderAllocator( - onnxruntime::IExecutionProvider* provider, - OrtAllocator** allocator) override try { - auto allocator_ptr = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); - *allocator = new (std::nothrow) AllocatorWrapper(allocator_ptr); - if (*allocator == nullptr) { - return E_OUTOFMEMORY; - } - - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE FreeProviderAllocator( - OrtAllocator* allocator) override try { - delete static_cast(allocator); - return S_OK; - } - WINMLA_CATCH_ALL_COM -}; // namespace Windows::AI::MachineLearning::Adapter -std::shared_ptr WinMLAdapter::lotus_environment_ = nullptr; - -extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) try { - // make an adapter instance - Microsoft::WRL::ComPtr adapterptr = wil::MakeOrThrow(); - return adapterptr.CopyTo(__uuidof(IWinMLAdapter), reinterpret_cast(adapter)); -} -WINMLA_CATCH_ALL_COM - -// InferenceSession -// ================ - -InferenceSession::InferenceSession(onnxruntime::InferenceSession* session) : session_(session) { -} - -void STDMETHODCALLTYPE InferenceSession::RegisterGraphTransformers() try { -#ifdef USE_DML - // Bug 22973884 : Fix issues with BatchNorm + Add and BatchNorm + Mul handling implicit inputs, and move from Winml to ORT - GraphTransformerHelpers::RegisterGraphTransformers(session_.get()); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -HRESULT STDMETHODCALLTYPE InferenceSession::StartProfiling() try { - this->session_->StartProfiling(PheonixSingleton()->GetDefaultLogger()); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT STDMETHODCALLTYPE InferenceSession::EndProfiling() try { - this->session_->EndProfiling(); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT STDMETHODCALLTYPE -InferenceSession::LoadModel( - IModelProto* model_proto) try { - auto session_protected_load_accessor = - static_cast(session_.get()); - // session's like to have their very own copy of the 
model_proto, use detach() - std::unique_ptr model_proto_ptr(model_proto->detach()); - ORT_THROW_IF_ERROR(session_protected_load_accessor->Load(std::move(model_proto_ptr))); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT STDMETHODCALLTYPE -InferenceSession::RegisterCustomRegistry( - IMLOperatorRegistry* registry) try { - RETURN_HR_IF(S_OK, registry == nullptr); - -#ifdef USE_DML - auto custom_registries = GetLotusCustomRegistries(registry); - - // Register - for (auto& custom_registry : custom_registries) { - ORT_THROW_IF_ERROR(session_->RegisterCustomRegistry(custom_registry)); - } -#endif USE_DML - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -void STDMETHODCALLTYPE InferenceSession::FlushContext(onnxruntime::IExecutionProvider* dml_provider) try { -#ifdef USE_DML - Dml::FlushContext(dml_provider); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -void STDMETHODCALLTYPE InferenceSession::TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) try { -#ifdef USE_DML - Dml::TrimUploadHeap(dml_provider); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -void STDMETHODCALLTYPE InferenceSession::ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) try { -#ifdef USE_DML - Dml::ReleaseCompletedReferences(dml_provider); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -HRESULT STDMETHODCALLTYPE InferenceSession::CopyOneInputAcrossDevices( - const char* input_name, - const OrtValue* orig_mlvalue, - OrtValue** new_mlvalue) try { - auto session_protected_load_accessor = - static_cast(session_.get()); - const onnxruntime::SessionState& sessionState = session_protected_load_accessor->GetSessionState(); - auto temp_mlvalue = std::make_unique(); - ORT_THROW_IF_ERROR(onnxruntime::utils::CopyOneInputAcrossDevices(sessionState, input_name, *orig_mlvalue, *temp_mlvalue.get())); - *new_mlvalue = temp_mlvalue.release(); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -} // namespace Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git 
a/winml/adapter/WinMLAdapter.h b/winml/adapter/WinMLAdapter.h deleted file mode 100644 index 6062e1e62f45a..0000000000000 --- a/winml/adapter/WinMLAdapter.h +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "core/session/onnxruntime_c_api.h" - -namespace Windows::AI::MachineLearning::Adapter { -TRACELOGGING_DECLARE_PROVIDER(winml_trace_logging_provider); - -MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{ - // model metadata - virtual const char* STDMETHODCALLTYPE author() = 0; - virtual const char* STDMETHODCALLTYPE name() = 0; - virtual const char* STDMETHODCALLTYPE domain() = 0; - virtual const char* STDMETHODCALLTYPE description() = 0; - virtual int64_t STDMETHODCALLTYPE version() = 0; - virtual HRESULT STDMETHODCALLTYPE GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView ** metadata) = 0; - virtual HRESULT STDMETHODCALLTYPE GetInputFeatures(ABI::Windows::Foundation::Collections::IVectorView * *features) = 0; - virtual HRESULT STDMETHODCALLTYPE GetOutputFeatures(ABI::Windows::Foundation::Collections::IVectorView * *features) = 0; -}; - -MIDL_INTERFACE("a848faf6-5a2e-4a7f-b622-cc036f71e28a") IModelProto : IUnknown{ - // this returns a weak ref - virtual onnx::ModelProto* STDMETHODCALLTYPE get() = 0; - // this returns the ownership without touching the reference and forgets about the object - virtual onnx::ModelProto* STDMETHODCALLTYPE detach() = 0; -}; - -MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnknown { - virtual onnxruntime::InferenceSession* STDMETHODCALLTYPE get() = 0; - // the below returns a weak ref , DO NOT RELEASE IT - virtual HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession ** out) = 0; - virtual void STDMETHODCALLTYPE RegisterGraphTransformers() = 0; - virtual HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry * registry) = 0; - virtual HRESULT STDMETHODCALLTYPE 
LoadModel(IModelProto* model_proto) = 0; - virtual HRESULT STDMETHODCALLTYPE StartProfiling() = 0; - virtual HRESULT STDMETHODCALLTYPE EndProfiling() = 0; - virtual void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider * dml_provider) = 0; - virtual void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) = 0; - virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0; - virtual HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(const char* input_name, - const OrtValue* orig_mlvalue, OrtValue** new_mlvalue) = 0; -}; - -// The IOrtSessionBuilder offers an abstraction over the creation of -// InferenceSession, that enables the creation of the session based on a device (CPU/DML). -MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown { - - virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions( - OrtSessionOptions ** options) = 0; - - virtual HRESULT STDMETHODCALLTYPE CreateSession( - OrtSessionOptions * options, - IInferenceSession** session, - onnxruntime::IExecutionProvider** provider) = 0; - - virtual HRESULT STDMETHODCALLTYPE Initialize( - IInferenceSession* session, - onnxruntime::IExecutionProvider* provider) = 0; -}; - - -MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown { - - virtual void STDMETHODCALLTYPE EnableDebugOutput() = 0; - - virtual HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility( - winml::LearningModel const& model, - IModelProto* p_model_proto, - bool is_float16_supported) = 0; - - // factory method for creating an ortsessionbuilder from a device - virtual HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder( - ID3D12Device* device, - ID3D12CommandQueue* queue, - IOrtSessionBuilder** session_builder) = 0; - - // factory methods for creating model protos - virtual HRESULT STDMETHODCALLTYPE CreateModelProto(const char* path, IModelProto** model_proto) = 0; - virtual HRESULT 
STDMETHODCALLTYPE CreateModelProto(ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, IModelProto** model_proto) = 0; - virtual HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) = 0; - virtual HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) = 0; - - // Data types - - // custom ops - virtual HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) = 0; - virtual HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative * operator_provider_native, IMLOperatorRegistry * *registry) = 0; - - // dml ep hooks - virtual void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) = 0; - virtual void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) = 0; - virtual HRESULT STDMETHODCALLTYPE CopyTensor(onnxruntime::IExecutionProvider* provider, OrtValue* src, OrtValue* dst) = 0; - // note: this returns a weak ref - virtual ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider * provider, void* allocation) = 0; - - // schema overrides (dml does this for us) - virtual HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() = 0; - - // proposed adapter. 
uses the cross plat ABI currencies - virtual HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo(onnxruntime::IExecutionProvider * provider, OrtMemoryInfo** memory_info) = 0; - virtual HRESULT STDMETHODCALLTYPE GetProviderAllocator(onnxruntime::IExecutionProvider * provider, OrtAllocator** allocator) = 0; - virtual HRESULT STDMETHODCALLTYPE FreeProviderAllocator(OrtAllocator* allocator) = 0; - virtual HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue * value, OrtMemoryInfo** memory_info) = 0; - virtual HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue * ort_value, ONNXTensorElementDataType * key_type, ONNXTensorElementDataType * value_type) = 0; - virtual HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue * ort_value, ONNXTensorElementDataType * key_type, ONNXTensorElementDataType * value_type) = 0; - //virtual HRESULT STDMETHODCALLTYPE CreateTensorFromMap(IInspectable * map, OrtValue * *ort_value) = 0; - //virtual HRESULT STDMETHODCALLTYPE CreateTensorFromSequence(IInspectable * sequence, OrtValue * *ort_value) = 0; -}; - -class InferenceSession : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - IInferenceSession> { - -public: - - InferenceSession(onnxruntime::InferenceSession * session); - - onnxruntime::InferenceSession* STDMETHODCALLTYPE get() noexcept override { - return session_.get(); - } - - HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession ** out) noexcept override { - // (OrtSession *) are really (InferenceSession *) as well - *out = reinterpret_cast(session_.get()); - return S_OK; - } - - void STDMETHODCALLTYPE RegisterGraphTransformers() override; - HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry* registry) override; - HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model_proto) override; - HRESULT STDMETHODCALLTYPE StartProfiling() override; - HRESULT STDMETHODCALLTYPE EndProfiling() override; - void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider* dml_provider) override; 
- void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) override; - void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) override; - HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(const char* input_name, - const OrtValue* orig_mlvalue, OrtValue** new_mlvalue) override; - - -private: - std::shared_ptr session_; -}; - -} // namespace Windows::AI::MachineLearning::Adapter - -namespace Ort { -// Ort::Allocator is not in the C ABI yet so it will have to be in the WinMLAdapter for now. -// This struct was copied using the Base struct from onnxruntime_cxx_api.h for reference -// Ort::Allocator struct is used as a smart pointer to OrtAllocator. -struct Allocator { - Allocator() { - m_ort_allocator = nullptr; - m_adapter = nullptr; - } - Allocator(winmla::IWinMLAdapter* adapter, OrtAllocator* ort_allocator) : - m_adapter(adapter), m_ort_allocator(ort_allocator) {} - - ~Allocator() { - if (m_adapter != nullptr && m_ort_allocator != nullptr) { - m_adapter->FreeProviderAllocator(m_ort_allocator); - } - } - - operator OrtAllocator*() { return m_ort_allocator; } - operator const OrtAllocator*() const { return m_ort_allocator; } - - OrtAllocator* release() { - OrtAllocator* p = m_ort_allocator; - m_ort_allocator = nullptr; - m_adapter = nullptr; - return p; - } - - OrtAllocator** put() noexcept { - assert(m_ort_allocator == nullptr); - return &m_ort_allocator; - } - - Allocator(const Allocator&) = delete; - Allocator& operator=(const Allocator&) = delete; - Allocator(Allocator&& v) noexcept : - m_adapter{v.m_adapter}, m_ort_allocator{v.m_ort_allocator} { - v.m_adapter = nullptr; - v.m_ort_allocator = nullptr; - } - void operator=(Allocator&& v) noexcept { - if (m_ort_allocator != nullptr && m_adapter != nullptr) { - m_adapter->FreeProviderAllocator(m_ort_allocator); - } - m_adapter = v.m_adapter; - m_ort_allocator = v.m_ort_allocator; - v.m_adapter = nullptr; - v.m_ort_allocator = nullptr; - } - - 
private: - winmla::IWinMLAdapter* m_adapter; - OrtAllocator* m_ort_allocator; -}; -} // namespace Ort \ No newline at end of file diff --git a/winml/adapter/WinMLAdapterErrors.h b/winml/adapter/WinMLAdapterErrors.h deleted file mode 100644 index 5513842761422..0000000000000 --- a/winml/adapter/WinMLAdapterErrors.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include "core/common/status.h" - -inline __declspec(noinline) winrt::hresult_error _winmla_to_hresult() noexcept { - try { - throw; - } catch (winrt::hresult_error const& e) { - return e; - } catch (wil::ResultException const& e) { - return winrt::hresult_error(e.GetErrorCode(), winrt::to_hstring(e.what())); - } catch (std::bad_alloc const&) { - return winrt::hresult_error(E_OUTOFMEMORY); - } catch (std::out_of_range const& e) { - return winrt::hresult_out_of_bounds(winrt::to_hstring(e.what())); - } catch (std::invalid_argument const& e) { - return winrt::hresult_invalid_argument(winrt::to_hstring(e.what())); - } catch (onnxruntime::OnnxRuntimeException const& e) { - StatusCode eStatusCode = static_cast(e.GetStatus().Code()); - return winrt::hresult_error(StatusCodeToHRESULT(eStatusCode), winrt::to_hstring(e.GetStatus().ErrorMessage())); - } catch (std::exception const& e) { - return winrt::hresult_error(E_FAIL, winrt::to_hstring(e.what())); - } catch (...) { - return winrt::hresult_error(E_FAIL); - } -} - -#define WINMLA_CATCH_ALL \ - catch (...) { \ - throw _winmla_to_hresult(); \ - } - -#define WINMLA_CATCH_ALL_COM \ - catch (...) { \ - return _winmla_to_hresult().to_abi(); \ - } - -#define WINMLA_CATCH_ALL_DONOTHING \ - catch (...) { \ - return; \ - } \ No newline at end of file diff --git a/winml/adapter/ZeroCopyInputStreamWrapper.cpp b/winml/adapter/ZeroCopyInputStreamWrapper.cpp deleted file mode 100644 index 1b53326719030..0000000000000 --- a/winml/adapter/ZeroCopyInputStreamWrapper.cpp +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -#include "pch.h" - -#include "ZeroCopyInputStreamWrapper.h" - -#include "winrt/Windows.Foundation.h" - -using namespace Windows::AI::MachineLearning; - -// ZeroCopyInputStreamWrapper -ZeroCopyInputStreamWrapper::ZeroCopyInputStreamWrapper( - ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream) { - winrt::copy_from_abi(stream_, (void*)stream); -} - -bool ZeroCopyInputStreamWrapper::Next( - const void** data, - int* size) { - if (finished_reading_) { - return false; - } - - auto content = stream_.OpenReadAsync().get(); - - wss::Buffer buffer(static_cast(content.Size())); - auto result = content.ReadAsync( - buffer, - buffer.Capacity(), - wss::InputStreamOptions::None) - .get(); - - bytes_ = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>(); -#ifdef LAYERING_DONE - WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes_, "Model stream is invalid."); - WINML_THROW_IF_FAILED_MSG( - bytes_->Buffer(reinterpret_cast(const_cast(data))), - "Failed to acquire buffer from model stream."); -#else - bytes_->Buffer(reinterpret_cast(const_cast(data))); -#endif - - *size = static_cast(content.Size()); - finished_reading_ = true; - return true; -} - -// BackUp is used when parsing encounters an error and needs to move -// back to the beginning of the erroneous chunk. We don't support random access, -// so we don't have a pointer to move back, but this can also happen for -// decrypted strings since they can have extra memory at the end that -// isn't valid. We don't want to parse non-model related data so we -// don't support this. I'd like to thrown an error here, but protobuf would -// eat that error and terminate the app. So instead we do nothing and handle -// this in LoadFromStream when the protobuf parsing returns false. -void ZeroCopyInputStreamWrapper::BackUp(int count) { - // purposely do nothing. 
-} - -// the following methods are required by the interface, -// but they aren't actually used by ModelProto parse code, -bool ZeroCopyInputStreamWrapper::Skip( - int count) { -#ifdef LAYERING_DONE - WINML_THROW_HR(E_NOTIMPL); -#endif - return false; -} - -__int64 -ZeroCopyInputStreamWrapper::ByteCount() const { -#ifdef LAYERING_DONE - WINML_THROW_HR(E_NOTIMPL); -#endif - return 0; -} diff --git a/winml/adapter/ZeroCopyInputStreamWrapper.h b/winml/adapter/ZeroCopyInputStreamWrapper.h deleted file mode 100644 index 8938468317606..0000000000000 --- a/winml/adapter/ZeroCopyInputStreamWrapper.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "winrt/Windows.Storage.Streams.h" -#include - -namespace Windows::AI::MachineLearning { -// _ZeroCopyInputStreamWrapper is a helper class that allows a ZeroCopyInputStream, -// which is a protobuf type, to read from an IRandomAccessStreamReference, which is -// a winrt type. 
-class ZeroCopyInputStreamWrapper : public google::protobuf::io::ZeroCopyInputStream { - public: - ZeroCopyInputStreamWrapper() = delete; - - ZeroCopyInputStreamWrapper( - ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream); - - // ModelProto load only uses "Next" method - bool - Next( - const void** data, - int* size); - - void - BackUp( - int count); - - bool - Skip( - int count); - - __int64 - ByteCount() const; - - private: - wss::IRandomAccessStreamReference stream_; - bool finished_reading_ = false; - winrt::com_ptr<::Windows::Storage::Streams::IBufferByteAccess> bytes_; -}; - -} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/adapter/AbiCustomRegistryImpl.cpp b/winml/adapter/abi_custom_registry_impl.cpp similarity index 98% rename from winml/adapter/AbiCustomRegistryImpl.cpp rename to winml/adapter/abi_custom_registry_impl.cpp index 7242ca121ffe5..00b20cba1b95f 100644 --- a/winml/adapter/AbiCustomRegistryImpl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -5,7 +5,7 @@ #ifdef USE_DML -#include "AbiCustomRegistryImpl.h" +#include "abi_custom_registry_impl.h" namespace Windows::AI::MachineLearning::Adapter { diff --git a/winml/adapter/AbiCustomRegistryImpl.h b/winml/adapter/abi_custom_registry_impl.h similarity index 93% rename from winml/adapter/AbiCustomRegistryImpl.h rename to winml/adapter/abi_custom_registry_impl.h index a07f51cacd067..77b8cba2897d4 100644 --- a/winml/adapter/AbiCustomRegistryImpl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -6,7 +6,7 @@ #ifdef USE_DML #include "core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h" -namespace Windows::AI::MachineLearning::Adapter{ +namespace Windows::AI::MachineLearning::Adapter { // An implementation of AbiCustomRegistry that emits telemetry events when operator kernels or schemas are registered. 
class AbiCustomRegistryImpl : public AbiCustomRegistry { @@ -38,5 +38,5 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { _In_opt_ IMLOperatorShapeInferrer* shape_inferrer) const noexcept override; }; -} // namespace winrt::Windows::AI::MachineLearning::Adapter +} // namespace Windows::AI::MachineLearning::Adapter #endif USE_DML diff --git a/winml/adapter/winml_adapter_apis.h b/winml/adapter/winml_adapter_apis.h new file mode 100644 index 0000000000000..1c33d5393ef47 --- /dev/null +++ b/winml/adapter/winml_adapter_apis.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "winml_adapter_c_api.h" + +namespace Windows { +namespace AI { +namespace MachineLearning { +namespace Adapter { + +ORT_API(void, ReleaseModel, OrtModel*); +ORT_API(void, ReleaseExecutionProvider, OrtExecutionProvider*); + +ORT_API_STATUS(OverrideSchema); + +// OrtEnv methods +ORT_API_STATUS(EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + +// OrtModel methods +ORT_API_STATUS(CreateModelFromPath, _In_ const char* model_path, _In_ size_t size, _Outptr_ OrtModel** out); +ORT_API_STATUS(CreateModelFromData, _In_ void* data, _In_ size_t size, _Outptr_ OrtModel** out); +ORT_API_STATUS(CloneModel, _In_ const OrtModel* in, _Outptr_ OrtModel** out); +ORT_API_STATUS(ModelGetAuthor, _In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len); +ORT_API_STATUS(ModelGetName, _In_ const OrtModel* model, _Out_ const char** const name, _Out_ size_t* len); +ORT_API_STATUS(ModelGetDomain, _In_ const OrtModel* model, _Out_ const char** const domain, _Out_ size_t* len); +ORT_API_STATUS(ModelGetDescription, _In_ const OrtModel* model, _Out_ const char** const description, _Out_ size_t* len); 
+ORT_API_STATUS(ModelGetVersion, _In_ const OrtModel* model, _Out_ int64_t* version); +ORT_API_STATUS(ModelGetInputCount, _In_ const OrtModel* model, _Out_ size_t* count); +ORT_API_STATUS(ModelGetOutputCount, _In_ const OrtModel* model, _Out_ size_t* count); +ORT_API_STATUS(ModelGetInputName, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_name, _Out_ size_t* count); +ORT_API_STATUS(ModelGetOutputName, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_name, _Out_ size_t* count); +ORT_API_STATUS(ModelGetInputDescription, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_description, _Out_ size_t* count); +ORT_API_STATUS(ModelGetOutputDescription, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_description, _Out_ size_t* count); +ORT_API_STATUS(ModelGetInputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info); +ORT_API_STATUS(ModelGetOutputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info); +ORT_API_STATUS(ModelGetMetadataCount, _In_ const OrtModel* model, _Out_ size_t* count); +ORT_API_STATUS(ModelGetMetadata, _In_ const OrtModel* model, _Out_ size_t count, _Out_ const char** const key, _Out_ size_t* key_len, _Out_ const char** const value, _Out_ size_t* value_len); +ORT_API_STATUS(ModelEnsureNoFloat16, _In_ const OrtModel* model); + +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, _In_ ID3D12Device* d3d_device, _In_ ID3D12CommandQueue* cmd_queue); + +// OrtSession methods +ORT_API_STATUS(CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session); + +//Do not release provider... 
as there is no release method available +ORT_API_STATUS(SessionGetExecutionProvider, _In_ OrtSession* session, size_t index, _Out_ OrtExecutionProvider** provider); +ORT_API_STATUS(SessionInitialize, _In_ OrtSession* session); +ORT_API_STATUS(SessionLoadAndPurloinModel, _In_ OrtSession* session, _In_ OrtModel* model); + +ORT_API_STATUS(SessionStartProfiling, _In_ OrtEnv* env, _In_ OrtSession* session); +ORT_API_STATUS(SessionEndProfiling, _In_ OrtSession* session); +ORT_API_STATUS(SessionRegisterGraphTransformers, _In_ OrtSession* session); +ORT_API_STATUS(SessionRegisterCustomRegistry, _In_ OrtSession* session, _In_ IMLOperatorRegistry* registry); +ORT_API_STATUS(SessionCopyOneInputAcrossDevices, _In_ OrtSession* session, _In_ const char* const input_name, _In_ OrtValue* orig_value, _Outptr_ OrtValue** new_value); + +// Dml methods (TODO need to figure out how these need to move to session somehow...) +ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled); +ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider); +ORT_API_STATUS(DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider); +ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider); +ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource); +ORT_API_STATUS(DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource); +ORT_API_STATUS(DmlFreeGPUAllocation, _In_ void* ptr); + +// note: this returns a weak ref + +ORT_API_STATUS(GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info); +ORT_API_STATUS(GetProviderAllocator, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator); +ORT_API_STATUS(FreeProviderAllocator, _In_ OrtAllocator* allocator); +ORT_API_STATUS(GetValueMemoryInfo, 
const OrtValue* value, OrtMemoryInfo** memory_info); + +// ExecutionProvider Methods +ORT_API_STATUS(ExecutionProviderSync, _In_ OrtExecutionProvider* provider); +ORT_API_STATUS(DmlCopyTensor, _In_ OrtExecutionProvider* provider, _In_ OrtValue* src, _In_ OrtValue* dst); +ORT_API_STATUS(CreateCustomRegistry, _Out_ IMLOperatorRegistry** registry); + +ORT_API_STATUS(ValueGetDeviceId, _In_ OrtValue* ort_value, _Out_ int16_t* device_id); +ORT_API_STATUS(SessionGetInputRequiredDeviceId, _In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id); + +} // namespace Adapter +} // namespace MachineLearning +} // namespace AI +} // namespace Windows \ No newline at end of file diff --git a/winml/adapter/winml_adapter_c_api.cpp b/winml/adapter/winml_adapter_c_api.cpp new file mode 100644 index 0000000000000..3ab5645893c8e --- /dev/null +++ b/winml/adapter/winml_adapter_c_api.cpp @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "winml_adapter_apis.h" +#include "core/session/ort_apis.h" + +#include +#include +#include + +const OrtApi* GetVersion1Api(); + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +static constexpr WinmlAdapterApi winml_adapter_api_1 = { + // Schema override + &winmla::OverrideSchema, + + // OrtEnv methods + &winmla::EnvConfigureCustomLoggerAndProfiler, + + // OrtTypeInfo Casting methods + &OrtApis::GetDenotationFromTypeInfo, + &OrtApis::CastTypeInfoToMapTypeInfo, + &OrtApis::CastTypeInfoToSequenceTypeInfo, + + // OrtMapTypeInfo Accessors + &OrtApis::GetMapKeyType, + &OrtApis::GetMapValueType, + + // OrtSequenceTypeInfo Accessors + &OrtApis::GetSequenceElementType, + + // OrtModel methods + &winmla::CreateModelFromPath, + &winmla::CreateModelFromData, + &winmla::CloneModel, + &winmla::ModelGetAuthor, + &winmla::ModelGetName, + &winmla::ModelGetDomain, + &winmla::ModelGetDescription, + &winmla::ModelGetVersion, + &winmla::ModelGetInputCount, + &winmla::ModelGetOutputCount, + &winmla::ModelGetInputName, + &winmla::ModelGetOutputName, + &winmla::ModelGetInputDescription, + &winmla::ModelGetOutputDescription, + &winmla::ModelGetInputTypeInfo, + &winmla::ModelGetOutputTypeInfo, + &winmla::ModelGetMetadataCount, + &winmla::ModelGetMetadata, + &winmla::ModelEnsureNoFloat16, + + // OrtSessionOptions methods + &OrtSessionOptionsAppendExecutionProvider_CPU, + &winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, + + // OrtSession methods + &winmla::CreateSessionWithoutModel, + &winmla::SessionGetExecutionProvider, + &winmla::SessionInitialize, + &winmla::SessionRegisterGraphTransformers, + &winmla::SessionRegisterCustomRegistry, + &winmla::SessionLoadAndPurloinModel, + &winmla::SessionStartProfiling, + &winmla::SessionEndProfiling, + &winmla::SessionCopyOneInputAcrossDevices, + + // Dml methods (TODO need to figure out how these need to move to session somehow...) 
+ &winmla::DmlExecutionProviderSetDefaultRoundingMode, + &winmla::DmlExecutionProviderFlushContext, + &winmla::DmlExecutionProviderTrimUploadHeap, + &winmla::DmlExecutionProviderReleaseCompletedReferences, + &winmla::DmlCreateGPUAllocationFromD3DResource, + &winmla::DmlFreeGPUAllocation, + &winmla::DmlGetD3D12ResourceFromAllocation, + &winmla::DmlCopyTensor, + + &winmla::GetProviderMemoryInfo, + &winmla::GetProviderAllocator, + &winmla::FreeProviderAllocator, + &winmla::GetValueMemoryInfo, + + &winmla::ExecutionProviderSync, + + &winmla::CreateCustomRegistry, + + &winmla::ValueGetDeviceId, + &winmla::SessionGetInputRequiredDeviceId, + + // Release + &winmla::ReleaseModel, + &OrtApis::ReleaseMapTypeInfo, + &OrtApis::ReleaseSequenceTypeInfo}; + +const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(const OrtApi* ort_api) NO_EXCEPTION { + if (GetVersion1Api() == ort_api) { + return &winml_adapter_api_1; + } + + return nullptr; +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_c_api.h b/winml/adapter/winml_adapter_c_api.h new file mode 100644 index 0000000000000..7f2e17259e0be --- /dev/null +++ b/winml/adapter/winml_adapter_c_api.h @@ -0,0 +1,469 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/session/onnxruntime_c_api.h" + +ORT_RUNTIME_CLASS(Model); +ORT_RUNTIME_CLASS(ExecutionProvider); + +struct WinmlAdapterApi; +typedef struct WinmlAdapterApi WinmlAdapterApi; + +struct ID3D12Resource; +struct ID3D12Device; +struct ID3D12CommandQueue; +struct IMLOperatorRegistry; + +// TODO: Must match onnxruntime::profiling::EventRecord +enum OrtProfilerEventCategory { + SESSION_EVENT = 0, + NODE_EVENT, + EVENT_CATEGORY_MAX +}; + +struct OrtProfilerEventRecord { + OrtProfilerEventCategory category_; + const char* category_name_; + int64_t duration_; + int64_t time_span_; + const char* event_name_; + int32_t process_id_; + int32_t thread_id_; + const char* op_name_; + const char* execution_provider_; +}; + +typedef void(ORT_API_CALL* OrtProfilingFunction)(const OrtProfilerEventRecord* event_record); + +struct WinmlAdapterApi { + /** + * OverrideSchema + * This api is used to override schema inference functions for a variety of ops across opsets. + * This exists because certain ops were failing to infer schemas and caused performance + * issues for DML as it was forced to create resources during evaluation. + * This can be removed when schema inference functions have been updated. + */ + OrtStatus*(ORT_API_CALL* OverrideSchema)() NO_EXCEPTION; + + /** + * EnvConfigureCustomLoggerAndProfiler + * This api is used to add a custom logger and profiler to the ort environment. + * This exists because existing methods on the c-abi to create the environment only support a custom logger. + * Since WinML hooks the profiler events, we expose the profiler and an associated profiling function. 
+ */ + OrtStatus*(ORT_API_CALL* EnvConfigureCustomLoggerAndProfiler)(_In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out)NO_EXCEPTION; + + /** + * GetDenotationFromTypeInfo + * This api augments OrtTypeInfo to return denotations on the type. + * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. + */ + OrtStatus*(ORT_API_CALL* GetDenotationFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len)NO_EXCEPTION; + + // OrtTypeInfo Casting methods + + /** + * CastTypeInfoToMapTypeInfo + * This api augments OrtTypeInfo to return an OrtMapTypeInfo when the type is a map. + * The OrtMapTypeInfo has additional information about the map's key type and value type. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ + OrtStatus*(ORT_API_CALL* CastTypeInfoToMapTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtMapTypeInfo** out)NO_EXCEPTION; + + /** + * CastTypeInfoToSequenceTypeInfo + * This api augments OrtTypeInfo to return an OrtSequenceTypeInfo when the type is a sequence. + * The OrtSequenceTypeInfo has additional information about the sequence's element type. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ + OrtStatus*(ORT_API_CALL* CastTypeInfoToSequenceTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtSequenceTypeInfo** out)NO_EXCEPTION; + + // OrtMapTypeInfo Accessors + + /** + * GetMapKeyType + * This api augments get the key type of a map. Key types are restricted to being scalar types and use ONNXTensorElementDataType. + * This is used by WinML to support model reflection APIs. 
+ */ + OrtStatus*(ORT_API_CALL* GetMapKeyType)(_In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION; + + /** + * GetMapValueType + * This api augments get the value type of a map. + */ + OrtStatus*(ORT_API_CALL* GetMapValueType)(_In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + // OrtSequenceTypeInfo Accessors + + /** + * GetSequenceElementType + * This api augments get the element type of a sequence. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* GetSequenceElementType)(_In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + // OrtModel methods + + /** + * CreateModelFromPath + * This api creates an OrtModel based on a specified model path. + * There is no inferencing or evaluation setup performed. Only ONNX load is done to reflect on the model's inputs/outputs and other properties. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* CreateModelFromPath)(_In_ const char* model_path, _In_ size_t size, _Outptr_ OrtModel** out)NO_EXCEPTION; + + /** + * CreateModelFromData + * This api creates an OrtModel from a buffer. + * There is no inferencing or evaluation setup performed. Only ONNX load is done to reflect on the model's inputs/outputs and other properties. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* CreateModelFromData)(_In_ void* data, _In_ size_t size, _Outptr_ OrtModel** out)NO_EXCEPTION; + + /** + * CloneModel + * This api copies the OrtModel along with its internal proto buffer and cached metadata. + * The OrtSession type expects to own the model proto buffer. + * WinML uses this to yield copies of the model proto held by OrtModel to OrtSession. 
+ */ + OrtStatus*(ORT_API_CALL* CloneModel)(_In_ const OrtModel* in, _Outptr_ OrtModel** out)NO_EXCEPTION; + + /** + * ModelGetAuthor + * This api gets the model author from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetAuthor)(_In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetName + * This api gets the model name from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetName)(_In_ const OrtModel* model, _Out_ const char** const name, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetDomain + * This api gets the model domain from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetDomain)(_In_ const OrtModel* model, _Out_ const char** const domain, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetDescription + * This api gets the model description from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetDescription)(_In_ const OrtModel* model, _Out_ const char** const description, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetVersion + * This api gets the model version from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetVersion)(_In_ const OrtModel* model, _Out_ int64_t* version)NO_EXCEPTION; + + /** + * ModelGetInputCount + * This api gets the number of inputs from the OrtModel. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetOutputCount + * This api gets the number of outputs from the OrtModel. 
It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetOutputCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetInputName + * This api gets the input name from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputName)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_name, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetOutputName + * This api gets the output name from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetOutputName)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_name, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetInputDescription + * This api gets the input description from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputDescription)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_description, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetOutputDescription + * This api gets the output description from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. 
+ */ + OrtStatus*(ORT_API_CALL* ModelGetOutputDescription)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_description, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetInputTypeInfo + * This api gets the input OrtTypeInfo from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputTypeInfo)(_In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + /** + * ModelGetOutputTypeInfo + * This api gets the output OrtTypeInfo from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetOutputTypeInfo)(_In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + /** + * ModelGetMetadataCount + * This api gets the number of metadata entries from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetMetadataCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetMetadata + * This api gets the model metadata from the OrtModel. + * This is used by WinML to deduce whether model input and output formats are supported by the WinML tensorization code paths. + */ + OrtStatus*(ORT_API_CALL* ModelGetMetadata)(_In_ const OrtModel* model, _Out_ size_t count, _Out_ const char** const key, _Out_ size_t* key_len, _Out_ const char** const value, _Out_ size_t* value_len)NO_EXCEPTION; + + /** + * ModelEnsureNoFloat16 + * This api checks whether the model requires float 16 support. + * This is used by WinML to fail gracefully when float 16 support is not available on the device. + * + * Can this API be moved into the EP during session initialization. 
Currently we do an early fp16 check to avoid initialization when it is not supported. + */ + OrtStatus*(ORT_API_CALL* ModelEnsureNoFloat16)(_In_ const OrtModel* model)NO_EXCEPTION; + + // OrtSessionOptions methods + + /** + * OrtSessionOptionsAppendExecutionProvider_CPU + * This api is used to add the cpu EP to OrtSessionOptions so that WinML Gpu session are configures with CPU fallback. + */ + OrtStatus*(ORT_API_CALL* OrtSessionOptionsAppendExecutionProvider_CPU)(_In_ OrtSessionOptions* options, int use_arena)NO_EXCEPTION; + + /** + * OrtSessionOptionsAppendExecutionProvider_DML + * This api is used to add the DML EP to OrtSessionOptions. + */ + OrtStatus*(ORT_API_CALL* OrtSessionOptionsAppendExecutionProvider_DML)(_In_ OrtSessionOptions* options, ID3D12Device* device, ID3D12CommandQueue* queue)NO_EXCEPTION; + + // OrtSession methods + + /** + * CreateSessionWithoutModel + * This api is used to create a Session that is completely uninitialized. While there are other Session creation APIs in the + * c-abi, WinML uses this so that it can perform optimizations prior to loading the model, and initializing. + * Moreover, WinML needs a new api to support the OrtModel type, and prevent the parsing model protobufs again on session creation. + */ + OrtStatus*(ORT_API_CALL* CreateSessionWithoutModel)(_In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session)NO_EXCEPTION; + + /** + * SessionGetExecutionProvider + * This api is used to get a handle to an OrtExecutionProvider. + * Currently WinML uses this to talk directly to the DML EP and configure settings on it. + */ + OrtStatus*(ORT_API_CALL* SessionGetExecutionProvider)(_In_ OrtSession* session, _In_ size_t index, _Out_ OrtExecutionProvider** provider)NO_EXCEPTION; + + /** + * SessionInitialize + * This api is used to initialize an OrtSession. This is one component of creating a usable OrtSession, and is a part of CreateSession in the c-abi. 
+ * Currently WinML uses this to finalize session creation, after configuring a variety of properties on the OrtSession. + */ + OrtStatus*(ORT_API_CALL* SessionInitialize)(_In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionRegisterGraphTransformers + * This api is used to enable DML specific graph transformations on an OrtSession. + * + * Ideally these transformations should be configured by the contract between the runtime and the EP and not overridden by WinML. + */ + OrtStatus*(ORT_API_CALL* SessionRegisterGraphTransformers)(_In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionRegisterCustomRegistry + * This api is used to support custom operators as they were shipped in WinML RS5. + */ + OrtStatus*(ORT_API_CALL* SessionRegisterCustomRegistry)(_In_ OrtSession* session, _In_ IMLOperatorRegistry* registry)NO_EXCEPTION; + + /** + * SessionLoadAndPurloinModel + * This api is used to load an OrtModel into an OrtSession. + * + * Don't free the 'out' value as this API will defunct and release the OrtModel internally. + */ + OrtStatus*(ORT_API_CALL* SessionLoadAndPurloinModel)(_In_ OrtSession* session, _In_ OrtModel* model)NO_EXCEPTION; + + /** + * SessionStartProfiling + * This api is used to start profiling OrtSession. The existing mechanism only allows configuring profiling at session creation. + * + * WinML uses this to toggle profilling on and off based on if a telemetry providers are being listened to. + */ + OrtStatus*(ORT_API_CALL* SessionStartProfiling)(_In_ OrtEnv* env, _In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionEndProfiling + * This api is used to end profiling OrtSession. The existing mechanism only allows configuring profiling at session creation. + * + * WinML uses this to toggle profilling on and off based on if a telemetry providers are being listened to. 
+ */ + OrtStatus*(ORT_API_CALL* SessionEndProfiling)(_In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionCopyOneInputAcrossDevices + * This api is used to copy and create an OrtValue input to prepare the input on the correct device. + * + * WinML uses this to copy gpu device OrtValues to the CPU and vice-versa. + */ + OrtStatus*(ORT_API_CALL* SessionCopyOneInputAcrossDevices)(_In_ OrtSession* session, _In_ const char* const input_name, _In_ OrtValue* orig_value, _Outptr_ OrtValue** new_value)NO_EXCEPTION; + + // Dml methods (TODO need to figure out how these need to move to session somehow...) + + /** + * DmlExecutionProviderSetDefaultRoundingMode + * This api is used to configure the DML EP to turn on/off rounding. + * + * WinML uses this to disable rounding during session initialization and then enables it again post initialization. + */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderSetDefaultRoundingMode)(_In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled)NO_EXCEPTION; + + /** + * DmlExecutionProviderFlushContext + * This api is used to flush the DML EP. + * + * WinML communicates directly with DML to perform this as an optimization. + */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderFlushContext)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION; + + /** + * DmlExecutionProviderTrimUploadHeap + * This api is used to trim the upload heap in the DML EP. + * + * WinML communicates directly with DML to perform this as an optimization. + */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderTrimUploadHeap)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION; + + /** + * DmlExecutionProviderReleaseCompletedReferences + * This api is used to release completed references after first run the DML EP. + * + * WinML communicates directly with DML to perform this as an optimization. 
+ */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderReleaseCompletedReferences)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION; + + /** + * DmlCreateGPUAllocationFromD3DResource + * This api is used to create a DML EP input based on a user specified d3d12 resource. + * + * WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs. + */ + OrtStatus*(ORT_API_CALL* DmlCreateGPUAllocationFromD3DResource)(_In_ ID3D12Resource* pResource, _Out_ void** dml_resource)NO_EXCEPTION; + + /** + * DmlFreeGPUAllocation + * This api is used free the DML EP input created by DmlCreateGPUAllocationFromD3DResource. + * + * WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs. + */ + OrtStatus*(ORT_API_CALL* DmlFreeGPUAllocation)(_In_ void* ptr)NO_EXCEPTION; + + /** + * DmlGetD3D12ResourceFromAllocation + * This api is used to get the D3D12 resource when a OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData. + * + * WinML uses this in the image feature path to get the d3d resource and perform and tensorization on inputs directly into the allocated d3d12 resource. + */ + OrtStatus*(ORT_API_CALL* DmlGetD3D12ResourceFromAllocation)(_In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource)NO_EXCEPTION; + + /** + * DmlCopyTensor + * This api is used copy a tensor allocated by the DML EP Allocator to the CPU. + * + * WinML uses this when graphs are evaluated with DML, and their outputs remain on the GPU but need to be copied back to the CPU. + */ + OrtStatus*(ORT_API_CALL* DmlCopyTensor)(_In_ OrtExecutionProvider* provider, _In_ OrtValue* src, _In_ OrtValue* dst)NO_EXCEPTION; + + /** + * GetProviderMemoryInfo + * This api gets the memory info object associated with an EP. + * + * WinML uses this to manage caller specified D3D12 inputs/outputs. 
It uses the memory info here to call DmlCreateGPUAllocationFromD3DResource. + */ + OrtStatus*(ORT_API_CALL* GetProviderMemoryInfo)(_In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info)NO_EXCEPTION; + + /** + * GetProviderAllocator + * This api gets associated allocator used by a provider. + * + * WinML uses this to create tensors, and needs to hold onto the allocator for the duration of the associated value's lifetime. + */ + OrtStatus*(ORT_API_CALL* GetProviderAllocator)(_In_ OrtExecutionProvider* provider, OrtAllocator** allocator)NO_EXCEPTION; + + /** + * FreeProviderAllocator + * This api frees an allocator. + * + * WinML uses this to free the associated allocator for an ortvalue when creating tensors. + * Internally this derefs a shared_ptr. + */ + OrtStatus*(ORT_API_CALL* FreeProviderAllocator)(_In_ OrtAllocator* allocator)NO_EXCEPTION; + + /** + * GetValueMemoryInfo + * This api gets the memory info of an OrtValue. + * + * WinML uses this to determine if an OrtValue is allocated on the Cpu or elsewhere. + */ + OrtStatus*(ORT_API_CALL* GetValueMemoryInfo)(const OrtValue* value, OrtMemoryInfo** memory_info)NO_EXCEPTION; + + /** + * ExecutionProviderSync + * This api syncs the EP. + * + * WinML uses this to sync EP inputs/outputs directly. + */ + OrtStatus*(ORT_API_CALL* ExecutionProviderSync)(_In_ OrtExecutionProvider* provider)NO_EXCEPTION; + + /** + * CreateCustomRegistry + * This api creates a custom registry that callers can populate with cusom ops. + * + * WinML uses this to support custom ops. + */ + OrtStatus*(ORT_API_CALL* CreateCustomRegistry)(_Out_ IMLOperatorRegistry** registry)NO_EXCEPTION; + + /** + * ValueGetDeviceId + * This api returns the device id of the OrtValue. + * + * WinML uses this to determine if an OrtValue is created on the needed device. 
+ */ + OrtStatus*(ORT_API_CALL* ValueGetDeviceId)(_In_ OrtValue* ort_value, _Out_ int16_t* device_id)NO_EXCEPTION; + + /** + * SessionGetInputRequiredDeviceId + * This api returns the required device id for a model input. + * + * WinML uses this to determine if an OrtValue is created on the needed device. + */ + OrtStatus*(ORT_API_CALL* SessionGetInputRequiredDeviceId)(_In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id)NO_EXCEPTION; + + ORT_CLASS_RELEASE(Model); + ORT_CLASS_RELEASE(MapTypeInfo); + ORT_CLASS_RELEASE(SequenceTypeInfo); +}; diff --git a/winml/adapter/winml_adapter_dml.cpp b/winml/adapter/winml_adapter_dml.cpp new file mode 100644 index 0000000000000..ddbd03475bc9f --- /dev/null +++ b/winml/adapter/winml_adapter_dml.cpp @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" + +#ifdef USE_DML +#include "core/session/abi_session_options_impl.h" +#include "core/providers/dml/dml_provider_factory.h" +#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +#endif // USE_DML + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +#ifdef USE_DML +Microsoft::WRL::ComPtr CreateDmlDevice(ID3D12Device* d3d12Device) { + // Dynamically load DML to avoid WinML taking a static dependency on DirectML.dll + wil::unique_hmodule dmlDll(LoadLibraryW(L"DirectML.dll")); + THROW_LAST_ERROR_IF(!dmlDll); + + auto dmlCreateDevice1Fn = reinterpret_cast( + GetProcAddress(dmlDll.get(), "DMLCreateDevice1")); + THROW_LAST_ERROR_IF(!dmlCreateDevice1Fn); + + DML_CREATE_DEVICE_FLAGS dmlFlags = DML_CREATE_DEVICE_FLAG_NONE; + + // Enable the DML debug layer in DEBUG builds, if the D3D12 debug layer is also enabled +#if _DEBUG + Microsoft::WRL::ComPtr d3d12DebugDevice; + if 
(SUCCEEDED(d3d12Device->QueryInterface(IID_PPV_ARGS(&d3d12DebugDevice)))) { + d3d12DebugDevice = nullptr; + dmlFlags |= DML_CREATE_DEVICE_FLAG_DEBUG; + } +#endif // USE_DML + + Microsoft::WRL::ComPtr dmlDevice; + THROW_IF_FAILED(dmlCreateDevice1Fn(d3d12Device, dmlFlags, DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dmlDevice))); + + // Keep DirectML.dll loaded by leaking the handle. This is equivalent behavior to if we delay-loaded the DLL. + dmlDll.release(); + + return dmlDevice; +} + +namespace onnxruntime { +void DmlConfigureProviderFactoryDefaultRoundingMode(onnxruntime::IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode); +} + +#endif // USE_DML + +ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, + ID3D12Device* d3d_device, ID3D12CommandQueue* queue) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_device = CreateDmlDevice(d3d_device); + if (auto status = OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue)) { + return status; + } + auto factory = options->provider_factories.back().get(); + + // OnnxRuntime uses the default rounding mode when calling the session's allocator. + // During initialization, OnnxRuntime allocates weights, which are permanent across session + // lifetime and can be large, so shouldn't be rounded. + // So we create the provider with rounding disabled, and expect the caller to enable it after. + onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::SetDefaultRoundingMode(dml_provider_internal, is_enabled ? 
AllocatorRoundingMode::Enabled : AllocatorRoundingMode::Disabled); +#endif + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::FlushContext(dml_provider_internal); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::TrimUploadHeap(dml_provider_internal); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::ReleaseCompletedReferences(dml_provider_internal); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource) { + API_IMPL_BEGIN +#ifdef USE_DML + *dml_resource = Dml::CreateGPUAllocationFromD3DResource(pResource); +#endif // USE_DML USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* dml_provider, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + *d3d_resource = + Dml::GetD3D12ResourceFromAllocation( + dml_provider_internal->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), + allocation); +#endif // USE_DML USE_DML + return nullptr; + API_IMPL_END +} + 
+ORT_API_STATUS_IMPL(winmla::DmlFreeGPUAllocation, _In_ void* ptr) { + API_IMPL_BEGIN +#ifdef USE_DML + Dml::FreeGPUAllocation(ptr); +#endif // USE_DML USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlCopyTensor, _In_ OrtExecutionProvider* dml_provider, _In_ OrtValue* src, _In_ OrtValue* dst) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + auto status = Dml::CopyTensor(dml_provider_internal, *(src->GetMutable()), *(dst->GetMutable())); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + return nullptr; +#else + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Out of memory"); +#endif // USE_DML USE_DML + API_IMPL_END +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp new file mode 100644 index 0000000000000..4aba907e4cb86 --- /dev/null +++ b/winml/adapter/winml_adapter_environment.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" +#include "core/session/onnxruntime_env.h" + +#ifdef USE_DML +#include "abi_custom_registry_impl.h" +#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" +#endif USE_DML + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +class WinmlAdapterLoggingWrapper : public LoggingWrapper { + public: + WinmlAdapterLoggingWrapper(OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, void* logger_param) : LoggingWrapper(logging_function, logger_param), + profiling_function_(profiling_function) { + } + + void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const override { + if (profiling_function_) { + OrtProfilerEventRecord ort_event_record = {}; + ort_event_record.category_ = static_cast(event_record.cat); + ort_event_record.category_name_ = onnxruntime::profiling::event_categor_names_[event_record.cat]; + ort_event_record.duration_ = event_record.dur; + ort_event_record.event_name_ = event_record.name.c_str(); + ort_event_record.execution_provider_ = (event_record.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) ? event_record.args["provider"].c_str() : nullptr; + ort_event_record.op_name_ = (event_record.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) ? 
event_record.args["op_name"].c_str() : nullptr; + ort_event_record.process_id_ = event_record.pid; + ort_event_record.thread_id_ = event_record.tid; + ort_event_record.time_span_ = event_record.ts; + + profiling_function_(&ort_event_record); + } + } + + private: + OrtProfilingFunction profiling_function_{}; +}; + +ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, + _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, + _In_ const char* logid, _Outptr_ OrtEnv** out) { + API_IMPL_BEGIN + std::string name = logid; + std::unique_ptr logger = onnxruntime::make_unique(logging_function, profiling_function, logger_param); + + // Clear the logging manager, since only one default instance of logging manager can exist at a time. + env->SetLoggingManager(nullptr); + + auto winml_logging_manager = std::make_unique(std::move(logger), + static_cast(default_warning_level), + false, + onnxruntime::logging::LoggingManager::InstanceType::Default, + &name); + + // Set a new default logging manager + env->SetLoggingManager(std::move(winml_logging_manager)); + return nullptr; + API_IMPL_END +} + +// Override select shape inference functions which are incomplete in ONNX with versions that are complete, +// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being +// deferred until first evaluation. It also prevents a situation where inference functions in externally +// registered schema are reachable only after upstream schema have been revised in a later OS release, +// which would be a compatibility risk. +ORT_API_STATUS_IMPL(winmla::OverrideSchema) { + API_IMPL_BEGIN +#ifdef USE_DML + static std::once_flag schema_override_once_flag; + std::call_once(schema_override_once_flag, []() { + SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); + }); +#endif USE_DML. 
+ return nullptr; + API_IMPL_END +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_execution_provider.cpp b/winml/adapter/winml_adapter_execution_provider.cpp new file mode 100644 index 0000000000000..a38af2af931c3 --- /dev/null +++ b/winml/adapter/winml_adapter_execution_provider.cpp @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +struct OrtAllocatorWrapper : public OrtAllocator { + public: + OrtAllocatorWrapper(onnxruntime::AllocatorPtr impl) : impl_(impl) { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + return static_cast(this_)->impl_->Alloc(size); + } + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + return static_cast(this_)->impl_->Free(p); + } + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + return &(static_cast(this_)->impl_->Info()); + } + + private: + onnxruntime::AllocatorPtr impl_; +}; + +ORT_API_STATUS_IMPL(winmla::ExecutionProviderSync, _In_ OrtExecutionProvider* provider) { + API_IMPL_BEGIN + const auto execution_provider = reinterpret_cast(provider); + execution_provider->Sync(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::GetProviderAllocator, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator) { + API_IMPL_BEGIN + const auto execution_provider = reinterpret_cast(provider); + auto allocator_ptr = execution_provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); + *allocator = new (std::nothrow) OrtAllocatorWrapper(allocator_ptr); 
+ if (*allocator == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info) { + API_IMPL_BEGIN + const auto execution_provider = reinterpret_cast(provider); + + auto allocator = execution_provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); + + const auto& info = allocator->Info(); + *memory_info = new (std::nothrow) OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); + if (*memory_info == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::FreeProviderAllocator, _In_ OrtAllocator* allocator) { + API_IMPL_BEGIN + delete static_cast(allocator); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info) { + API_IMPL_BEGIN + const auto& tensor = value->Get(); + auto info = tensor.Location(); + *memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); + if (*memory_info == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); + } + return nullptr; + API_IMPL_END +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp new file mode 100644 index 0000000000000..6e4d22588aec3 --- /dev/null +++ b/winml/adapter/winml_adapter_model.cpp @@ -0,0 +1,429 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once +#include "pch.h" + +#include "winml_adapter_model.h" + +#include "winml_adapter_c_api.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" + +#include +#include +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "core/framework/onnxruntime_typeinfo.h" + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +static std::vector GetInitializers(const onnx::ModelProto& model_proto) { + std::vector initializers; + auto& graph = model_proto.graph(); + auto& graph_initializers = graph.initializer(); + for (auto& initializer : graph_initializers) { + initializers.push_back(initializer.name().c_str()); + } + return initializers; +} + +static std::vector GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { + auto initializers = GetInitializers(model_proto); + + std::vector inputs_without_initializers; + auto& graph = model_proto.graph(); + auto& inputs = graph.input(); + for (auto& input : inputs) { + if (input.has_name() && input.has_type()) { + auto found_it = std::find_if( + std::begin(initializers), + std::end(initializers), + [&](auto& initializer) { + return std::strcmp(initializer, input.name().c_str()) == 0; + }); + + auto is_initializer = found_it != std::end(initializers); + if (!is_initializer) { + inputs_without_initializers.push_back(&input); + } + } + } + return inputs_without_initializers; +} + +static std::vector GetOutputs(const onnx::ModelProto& model_proto) { + std::vector outputs_with_name; + auto& graph = model_proto.graph(); + auto& outputs = graph.output(); + for (auto& output : outputs) { + if (output.has_name() && output.has_type()) { + outputs_with_name.push_back(&output); + } + } + return outputs_with_name; +} + +class ModelInfo { + public: + ModelInfo(const onnx::ModelProto* model_proto) { + Initialize(model_proto); + } + + public: + // model metadata + std::string author_; + std::string name_; 
+ std::string domain_; + std::string description_; + int64_t version_; + std::vector> model_metadata_; + std::vector input_features_; + std::vector output_features_; + bool requires_float16_support_; + + private: + void Initialize(const onnx::ModelProto* model_proto) { + for (auto& prop : model_proto->metadata_props()) { + model_metadata_.push_back(std::make_pair(prop.key(), prop.value())); + } + + input_features_ = GetInputsWithoutInitializers(*model_proto); + output_features_ = ::GetOutputs(*model_proto); + + auto has_producer_name = model_proto->has_producer_name(); + author_ = has_producer_name ? model_proto->producer_name() : ""; + + auto has_domain = model_proto->has_domain(); + domain_ = has_domain ? model_proto->domain() : ""; + + auto has_graph = model_proto->has_graph(); + auto graph_has_name = model_proto->graph().has_name(); + auto is_name_available = has_graph && graph_has_name; + name_ = is_name_available ? model_proto->graph().name() : ""; + + auto has_description = model_proto->has_doc_string(); + description_ = has_description ? model_proto->doc_string() : ""; + + auto has_version = model_proto->has_model_version(); + version_ = has_version ? 
model_proto->model_version() : 0; + } +}; + +OrtModel::OrtModel(std::unique_ptr model_proto) : model_proto_(std::move(model_proto)), + model_info_(std::make_unique(model_proto_.get())) { +} + +// factory methods for creating an ort model from a path +static OrtStatus* CreateModelProto(const char* path, std::unique_ptr& out) { + int file_descriptor; + _set_errno(0); // clear errno + _sopen_s( + &file_descriptor, + path, + O_RDONLY | _O_SEQUENTIAL | _O_BINARY, + _SH_DENYWR, + _S_IREAD | _S_IWRITE); + + errno_t err = 0; + _get_errno(&err); + if (err == ENOENT) { + return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!"); + } + + if (0 > file_descriptor) { + return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!"); + } + + google::protobuf::io::FileInputStream stream(file_descriptor); + stream.SetCloseOnDelete(true); + + auto model_proto = std::unique_ptr(new onnx::ModelProto()); + + auto parse_succeeded = model_proto->ParseFromZeroCopyStream(&stream); + if (!parse_succeeded) { + return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model file!"); + } + + out = std::move(model_proto); + + return S_OK; +} + +OrtStatus* OrtModel::CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model) { + ORT_UNUSED_PARAMETER(len); + + std::unique_ptr model_proto; + + if (auto status = CreateModelProto(path, model_proto)) { + return status; + } + + return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model); +} + +OrtStatus* OrtModel::CreateOrtModelFromData(void* data, size_t len, OrtModel** model) { + auto model_proto = std::unique_ptr(new onnx::ModelProto()); + + auto parse_succeeded = model_proto->ParseFromArray(data, static_cast(len)); + if (!parse_succeeded) { + return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model stream!"); + } + + return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model); +} + +OrtStatus* OrtModel::CreateOrtModelFromProto(std::unique_ptr&& model_proto, 
OrtModel** model) { + *model = new (std::nothrow) OrtModel(std::move(model_proto)); + if (*model == nullptr) { + return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Engine failed to create a model!"); + } + + return nullptr; +} + +const ModelInfo* OrtModel::UseModelInfo() const { + return model_info_.get(); +} + +const ONNX_NAMESPACE::ModelProto* OrtModel::UseModelProto() const { + return model_proto_.get(); +} + +std::unique_ptr OrtModel::DetachModelProto() { + return std::move(model_proto_); +} + +ORT_API_STATUS_IMPL(winmla::CreateModelFromPath, const char* model_path, size_t size, OrtModel** out) { + API_IMPL_BEGIN + if (auto status = OrtModel::CreateOrtModelFromPath(model_path, size, out)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CreateModelFromData, void* data, size_t size, OrtModel** out) { + API_IMPL_BEGIN + if (auto status = OrtModel::CreateOrtModelFromData(data, size, out)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CloneModel, const OrtModel* in, OrtModel** out) { + API_IMPL_BEGIN + auto model_proto_copy = std::make_unique(*in->UseModelProto()); + if (auto status = OrtModel::CreateOrtModelFromProto(std::move(model_proto_copy), out)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetAuthor, const OrtModel* model, const char** const author, size_t* len) { + API_IMPL_BEGIN + *author = model->UseModelInfo()->author_.c_str(); + *len = model->UseModelInfo()->author_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetName, const OrtModel* model, const char** const name, size_t* len) { + API_IMPL_BEGIN + *name = model->UseModelInfo()->name_.c_str(); + *len = model->UseModelInfo()->name_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetDomain, const OrtModel* model, const char** const domain, size_t* len) { + API_IMPL_BEGIN + *domain = 
model->UseModelInfo()->domain_.c_str(); + *len = model->UseModelInfo()->domain_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetDescription, const OrtModel* model, const char** const description, size_t* len) { + API_IMPL_BEGIN + *description = model->UseModelInfo()->description_.c_str(); + *len = model->UseModelInfo()->description_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetVersion, const OrtModel* model, int64_t* version) { + API_IMPL_BEGIN + *version = model->UseModelInfo()->version_; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetMetadataCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN + *count = model->UseModelInfo()->model_metadata_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetMetadata, const OrtModel* model, size_t count, const char** const key, + size_t* key_len, const char** const value, size_t* value_len) { + API_IMPL_BEGIN + *key = model->UseModelInfo()->model_metadata_[count].first.c_str(); + *key_len = model->UseModelInfo()->model_metadata_[count].first.size(); + *value = model->UseModelInfo()->model_metadata_[count].second.c_str(); + *value_len = model->UseModelInfo()->model_metadata_[count].second.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN + *count = model->UseModelInfo()->input_features_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN + *count = model->UseModelInfo()->output_features_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputName, const OrtModel* model, size_t index, + const char** input_name, size_t* count) { + API_IMPL_BEGIN + *input_name = model->UseModelInfo()->input_features_[index]->name().c_str(); + *count = 
model->UseModelInfo()->input_features_[index]->name().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t index, + const char** output_name, size_t* count) { + API_IMPL_BEGIN + *output_name = model->UseModelInfo()->output_features_[index]->name().c_str(); + *count = model->UseModelInfo()->output_features_[index]->name().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, + const char** input_description, size_t* count) { + API_IMPL_BEGIN + *input_description = model->UseModelInfo()->input_features_[index]->doc_string().c_str(); + *count = model->UseModelInfo()->input_features_[index]->doc_string().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, + const char** output_description, size_t* count) { + API_IMPL_BEGIN + *output_description = model->UseModelInfo()->output_features_[index]->doc_string().c_str(); + *count = model->UseModelInfo()->output_features_[index]->doc_string().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) { + API_IMPL_BEGIN + if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->input_features_[index]->type(), type_info)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) { + API_IMPL_BEGIN + if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->output_features_[index]->type(), type_info)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) { + API_IMPL_BEGIN + auto model_info = model->UseModelInfo(); + auto model_proto = model->UseModelProto(); + auto& graph 
= model_proto->graph(); + + // The model will not contain fp16 operations if: + // 1. The model has no fp16 inputs + // 2. The model has no fp16 initializers + // 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator + // 4. The model does not have any fp16 outputs + + // 1. Ensure that The model has no fp16 inputs + for (auto input : model_info->input_features_) { + auto& type = input->type(); + if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { + auto& tensor_type = type.tensor_type(); + if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit input (" + << input->name() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + } + + // 2. Ensure that the model has no fp16 initializers + for (int i = 0; i < graph.node_size(); i++) { + auto node = graph.node(i); + if (node.op_type() == "Cast" && node.domain().empty()) { + for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { + auto attribute = node.attribute(attribIndex); + if (attribute.name() == "to") { + if (attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit input (" + << node.name().c_str() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + } + } + } + + // 3. 
Ensure that the model does not create any fp16 intermediary + // tensors via the Cast (to float16) operator + for (int i = 0; i < graph.initializer_size(); i++) { + auto initializer = graph.initializer(i); + if (initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit input (" + << initializer.name().c_str() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + + // 4. Ensure that the model does not have any fp16 outputs + for (auto output : model_info->output_features_) { + auto& type = output->type(); + if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { + auto& tensor_type = type.tensor_type(); + if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit input (" + << output->name() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + } + return nullptr; + API_IMPL_END +} + +ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) { + delete ptr; +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_model.h b/winml/adapter/winml_adapter_model.h new file mode 100644 index 0000000000000..df245f75c7941 --- /dev/null +++ b/winml/adapter/winml_adapter_model.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include "winml_adapter_c_api.h" +#include +#include "core/graph/onnx_protobuf.h" + +class ModelInfo; + +struct OrtModel { + public: + static OrtStatus* CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model); + static OrtStatus* CreateOrtModelFromData(void* data, size_t len, OrtModel** model); + static OrtStatus* CreateOrtModelFromProto(std::unique_ptr&& model_proto, OrtModel** model); + const ModelInfo* UseModelInfo() const; + + const onnx::ModelProto* UseModelProto() const; + std::unique_ptr DetachModelProto(); + + private: + OrtModel(std::unique_ptr model_proto); + OrtModel(const OrtModel& other) = delete; + OrtModel& operator=(const OrtModel& other) = delete; + + private: + std::unique_ptr model_proto_; + std::unique_ptr model_info_; +}; diff --git a/winml/adapter/winml_adapter_session.cpp b/winml/adapter/winml_adapter_session.cpp new file mode 100644 index 0000000000000..1a65f1e885677 --- /dev/null +++ b/winml/adapter/winml_adapter_session.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" + +#include "core/session/inference_session.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/onnxruntime_env.h" + +#include "winml_adapter_model.h" +#include "core/framework/utils.h" + +#ifdef USE_DML +#include "core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h" +#include "abi_custom_registry_impl.h" +#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" +#endif USE_DML + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +// ORT intentionally requires callers derive from their session class to access +// the protected methods used below. 
+class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSession { + public: + onnxruntime::common::Status + Load(std::unique_ptr p_model_proto) { + return onnxruntime::InferenceSession::Load(std::move(p_model_proto)); + } + const onnxruntime::SessionState& GetSessionState() { + return *session_state_; + } +}; + +ORT_API_STATUS_IMPL(winmla::CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session) { + API_IMPL_BEGIN + std::unique_ptr inference_session; + try { + // Create the inference session + inference_session = std::make_unique(options->value, env->GetLoggingManager()); + } catch (const std::exception& e) { + return OrtApis::CreateStatus(ORT_FAIL, e.what()); + } + + // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of + // byte addressable memory + std::vector> provider_list; + if (options) { + for (auto& factory : options->provider_factories) { + auto provider = factory->CreateProvider(); + if (provider->Type() == onnxruntime::kDmlExecutionProvider) { + if (options->value.enable_mem_pattern) { + // TODO Instead of returning an error, should we set mem pattern to false here and log a warning saying so? + // Doing so would be inconsistent with the Python API that doesn't go through this code path. 
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Mem pattern should be disabled when using DML execution provider."); + } + if (options->value.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Sequential execution should be enabled when using DML execution provider."); + } + } + provider_list.push_back(std::move(provider)); + } + } + + Status status; + if (options) { + if (!options->custom_op_domains_.empty()) { + status = inference_session->AddCustomOpDomains(options->custom_op_domains_); + if (!status.IsOK()) + return onnxruntime::ToOrtStatus(status); + } + } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + inference_session->RegisterExecutionProvider(std::move(provider)); + } + } + + *session = reinterpret_cast(inference_session.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionGetExecutionProvider, _In_ OrtSession* session, _In_ size_t index, _Out_ OrtExecutionProvider** ort_provider) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast(inference_session); + const auto& session_state = session_protected_load_accessor->GetSessionState(); + auto& provider_id = session_state.GetExecutionProviders().GetIds().at(index); + const auto& provider = session_state.GetExecutionProviders().Get(provider_id); + + *ort_provider = const_cast(reinterpret_cast(provider)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionInitialize, _In_ OrtSession* session) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto status = inference_session->Initialize(); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionLoadAndPurloinModel, _In_ OrtSession* session, _In_ 
OrtModel* model) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast(inference_session); + + auto status = session_protected_load_accessor->Load(model->DetachModelProto()); + + ReleaseModel(model); + + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionStartProfiling, _In_ OrtEnv* env, _In_ OrtSession* session) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + inference_session->StartProfiling(&env->GetLoggingManager()->DefaultLogger()); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionEndProfiling, _In_ OrtSession* session) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + inference_session->EndProfiling(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionRegisterGraphTransformers, _In_ OrtSession* session) { + API_IMPL_BEGIN +#ifdef USE_DML + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + + // Bug 22973884 : Fix issues with BatchNorm + Add and BatchNorm + Mul handling implicit inputs, and move from Winml to ORT + GraphTransformerHelpers::RegisterGraphTransformers(inference_session); +#endif USE_DML + return nullptr; + API_IMPL_END +} + +inline std::list> +GetLotusCustomRegistries(IMLOperatorRegistry* registry) { + if (registry != nullptr) { + // Down-cast to the concrete type. + // The only supported input is the AbiCustomRegistry type. + // Other implementations of IMLOperatorRegistry are forbidden. 
+ auto abi_custom_registry = + static_cast(registry); + + // Get the ORT registry + return abi_custom_registry->GetRegistries(); + } + return {}; +} + +ORT_API_STATUS_IMPL(winmla::SessionRegisterCustomRegistry, _In_ OrtSession* session, _In_ IMLOperatorRegistry* registry) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto custom_registries = GetLotusCustomRegistries(registry); + + // Register + for (auto& custom_registry : custom_registries) { + ORT_THROW_IF_ERROR(inference_session->RegisterCustomRegistry(custom_registry)); + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CreateCustomRegistry, _Out_ IMLOperatorRegistry** registry) { + API_IMPL_BEGIN + auto impl = wil::MakeOrThrow(); + *registry = impl.Detach(); + return nullptr; + API_IMPL_END +} + +static OrtDevice GetSessionGetInputDevice(_In_ OrtSession* session, _In_ const char* const input_name) { + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast(inference_session); + const onnxruntime::SessionState& session_state = session_protected_load_accessor->GetSessionState(); + + std::vector node_info_vec; + session_state.GetInputNodeInfo(input_name, node_info_vec); + const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine + return *node_info.device; +} + +ORT_API_STATUS_IMPL(winmla::SessionGetInputRequiredDeviceId, _In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id) { + auto device = GetSessionGetInputDevice(session, input_name); + *device_id = device.Id(); + return nullptr; +} + +ORT_API_STATUS_IMPL(winmla::ValueGetDeviceId, _In_ OrtValue* ort_value, _Out_ int16_t* device_id) { + auto device = ort_value->Get().Location().device; + *device_id = device.Id(); + return nullptr; +} + +ORT_API_STATUS_IMPL(winmla::SessionCopyOneInputAcrossDevices, _In_ 
OrtSession* session, _In_ const char* const input_name, + _In_ OrtValue* orig_value, _Outptr_ OrtValue** new_value) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast(inference_session); + const onnxruntime::SessionState& session_state = session_protected_load_accessor->GetSessionState(); + + auto ort_value = std::make_unique(); + auto status = onnxruntime::utils::CopyOneInputAcrossDevices(session_state, input_name, *orig_value, *ort_value.get()); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + + *new_value = ort_value.release(); + + return nullptr; + API_IMPL_END +} \ No newline at end of file diff --git a/winml/api/Windows.AI.MachineLearning.idl b/winml/api/Windows.AI.MachineLearning.idl index 0380f4a02b7b1..7dddba0afdbb4 100644 --- a/winml/api/Windows.AI.MachineLearning.idl +++ b/winml/api/Windows.AI.MachineLearning.idl @@ -20,7 +20,7 @@ import "windows.storage.idl"; namespace Windows.AI.MachineLearning { - [contractversion(4)] + [contractversion(3)] apicontract MachineLearningContract{}; //! Forward declarations @@ -334,18 +334,6 @@ namespace Windows.AI.MachineLearning TensorKind KeyKind{ get; }; //! Returns the properties of the map's value. ILearningModelFeatureDescriptor ValueDescriptor{ get; }; - - [contract(MachineLearningContract, 4)] - { - //! allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] MapFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - TensorKind KeyKind, - ILearningModelFeatureDescriptor ValueDescriptor - ); - } } //! \class SequenceFeatureDescriptor @@ -358,17 +346,6 @@ namespace Windows.AI.MachineLearning { //! Gets the properties of the specified feature. ILearningModelFeatureDescriptor ElementDescriptor{ get; }; - - [contract(MachineLearningContract, 4)] - { - //! 
allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] SequenceFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - ILearningModelFeatureDescriptor ElementDescriptor - ); - } } //! \class TensorFeatureDescriptor @@ -383,23 +360,6 @@ namespace Windows.AI.MachineLearning TensorKind TensorKind{ get; }; //! Returns the count and size of each dimension. Windows.Foundation.Collections.IVectorView Shape{ get; }; - - [contract(MachineLearningContract, 4)] - { - //! allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] TensorFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - TensorKind TensorKind, - Int64[] Shape, - Boolean HasUnsupportedImageMetadata - ); - //! if this feature is an image but has unsupport image metadata (like a BitmapPixelFormat) - //! you can still use the runtime but without image conversion support. - //! this setting will be 'true' - Boolean HasUnsupportedImageMetadata{ get; } ; - } } //! \class ImageFeatureDescriptor @@ -418,26 +378,6 @@ namespace Windows.AI.MachineLearning UInt32 Width{ get; }; //! The height of the image. UInt32 Height{ get; }; - - [contract(MachineLearningContract, 4)] - { - //! allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] ImageFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - TensorKind TensorKind, - Int64[] Shape, - Windows.Graphics.Imaging.BitmapPixelFormat BitmapPixelFormat, - Windows.Graphics.Imaging.BitmapAlphaMode BitmapAlphaMode, - UInt32 Width, - UInt32 Height - ); - - //! Returns the data type of the tensor. This is useful if you want to know - //! if it's fp16 or fp32 - TensorKind TensorKind{ get; }; - } } //! 
\interface ITensor diff --git a/winml/dll/module.cpp b/winml/dll/module.cpp index 531521edc834a..8c7123f880c85 100644 --- a/winml/dll/module.cpp +++ b/winml/dll/module.cpp @@ -6,19 +6,19 @@ #include #include "LearningModelDevice.h" +#include "OnnxruntimeProvider.h" using namespace winrt::Windows::AI::MachineLearning::implementation; -void __stdcall OnErrorReported(bool alreadyReported, wil::FailureInfo const &failure) WI_NOEXCEPT { +void __stdcall OnErrorReported(bool alreadyReported, wil::FailureInfo const& failure) WI_NOEXCEPT { if (!alreadyReported) { winrt::hstring message(failure.pszMessage ? failure.pszMessage : L""); telemetry_helper.LogRuntimeError( - failure.hr, - winrt::to_string(message), - failure.pszFile, - failure.pszFunction, - failure.uLineNumber - ); + failure.hr, + winrt::to_string(message), + failure.pszFile, + failure.pszFunction, + failure.uLineNumber); } } @@ -57,10 +57,10 @@ extern "C" BOOL WINAPI DllMain(_In_ HINSTANCE hInstance, DWORD dwReason, _In_ vo } extern "C" HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegistry** registry) try { - *registry = nullptr; - winrt::com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - return adapter->GetCustomRegistry(registry); + winrt::com_ptr engine_factory; + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory.put())); + WINML_THROW_IF_FAILED(engine_factory->CreateCustomRegistry(registry)); + return S_OK; } CATCH_RETURN(); diff --git a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp new file mode 100644 index 0000000000000..9aeedb7613099 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "pch.h" + +#include "OnnxruntimeCpuSessionBuilder.h" +#include "OnnxruntimeEngine.h" +#include "OnnxruntimeErrors.h" + +using namespace Windows::AI::MachineLearning; + +HRESULT OnnxruntimeCpuSessionBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) { + engine_factory_ = engine_factory; + return S_OK; +} + +HRESULT +OnnxruntimeCpuSessionBuilder::CreateSessionOptions( + OrtSessionOptions** options) { + RETURN_HR_IF_NULL(E_POINTER, options); + + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtSessionOptions* ort_options; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&ort_options), + ort_api); + + auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions); + + // set the graph optimization level to all (used to be called level 3) + RETURN_HR_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL), + ort_api); + + // Onnxruntime will use half the number of concurrent threads supported on the system + // by default. This causes MLAS to not exercise every logical core. + // We force the thread pool size to be maxxed out to ensure that WinML always + // runs the fastest. 
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->SetIntraOpNumThreads(session_options.get(), std::thread::hardware_concurrency()), + ort_api); + +#ifndef _WIN64 + auto use_arena = false; +#else + auto use_arena = true; +#endif + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena), + ort_api); + + // call release() so the underlying OrtSessionOptions object isn't freed + *options = session_options.release(); + + return S_OK; +} + +HRESULT +OnnxruntimeCpuSessionBuilder::CreateSession( + OrtSessionOptions* options, + OrtSession** session) { + RETURN_HR_IF_NULL(E_POINTER, session); + + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtEnv* ort_env; + RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env)); + + OrtSession* ort_session_raw; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw), + engine_factory_->UseOrtApi()); + + auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession); + + *session = ort_session.release(); + + return S_OK; +} + +HRESULT +OnnxruntimeCpuSessionBuilder::Initialize( + OrtSession* session) { + RETURN_HR_IF_NULL(E_INVALIDARG, session); + + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session), + engine_factory_->UseOrtApi()); + + return S_OK; +} diff --git a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h new file mode 100644 index 0000000000000..d9f4a12375316 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#include "OnnxruntimeSessionBuilder.h" + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineFactory; + +class OnnxruntimeCpuSessionBuilder : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IOrtSessionBuilder> { + public: + HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory); + + HRESULT STDMETHODCALLTYPE CreateSessionOptions( + OrtSessionOptions** options) override; + + HRESULT STDMETHODCALLTYPE CreateSession( + OrtSessionOptions* options, + OrtSession** session) override; + + HRESULT STDMETHODCALLTYPE Initialize( + OrtSession* session) override; + + private: + Microsoft::WRL::ComPtr engine_factory_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/adapter/FeatureDescriptorFactory.cpp b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp similarity index 58% rename from winml/adapter/FeatureDescriptorFactory.cpp rename to winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp index 196ccf517dc64..db4cf60062a45 100644 --- a/winml/adapter/FeatureDescriptorFactory.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp @@ -5,7 +5,7 @@ #include -#include "FeatureDescriptorFactory.h" +#include "OnnxruntimeDescriptorConverter.h" #include "ImageFeatureDescriptor.h" #include "MapFeatureDescriptor.h" #include "SequenceFeatureDescriptor.h" @@ -13,7 +13,11 @@ #include "winrt/windows.foundation.collections.h" #include "winrt/windows.graphics.imaging.h" -#include "WinMLAdapter.h" + +#include "OnnxruntimeEngine.h" + +#include "OnnxruntimeErrors.h" + using namespace winrt::Windows::AI::MachineLearning; // BitmapPixelFormat constants @@ -42,152 +46,64 @@ static const char* c_supported_nominal_ranges[] = namespace Windows::AI::MachineLearning { - -// since this code is now running inside ONNXRUNTIME we need to shortcut -// this a bit when creating winrt objects. This will help. 
- -/* extern "C" -HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept; - -#ifdef _M_IX86 -#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12") -#else -#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory") -#endif -*/ - -bool starts_with(std::wstring_view value, std::wstring_view match) noexcept -{ - return 0 == value.compare(0, match.size(), match); -} - -EXTERN_C IMAGE_DOS_HEADER __ImageBase; - -std::wstring GetModulePath() -{ - std::wstring val; - wchar_t modulePath[MAX_PATH] = { 0 }; - GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath)); - wchar_t drive[_MAX_DRIVE]; - wchar_t dir[_MAX_DIR]; - wchar_t filename[_MAX_FNAME]; - wchar_t ext[_MAX_EXT]; - _wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT); - - val = drive; - val += dir; - - return val; -} - -extern "C" int32_t __stdcall WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept { - *factory = nullptr; - HSTRING classId_hstring = (HSTRING)classId; - std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) }; - HMODULE library{ nullptr }; - - std::wstring winmlDllPath = GetModulePath() + L"Windows.AI.MachineLearning.dll"; - - if (starts_with(name, L"Windows.AI.MachineLearning.")) - { - const wchar_t* libPath = winmlDllPath.c_str(); - library = LoadLibraryW(libPath); - } - else - { - return RoGetActivationFactory(classId_hstring, iid, factory); - } - - if (!library) - { - return HRESULT_FROM_WIN32(GetLastError()); - } - - using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory); - auto call = reinterpret_cast(GetProcAddress(library, "DllGetActivationFactory")); - - if (!call) - { - HRESULT const hr = HRESULT_FROM_WIN32(GetLastError()); - WINRT_VERIFY(FreeLibrary(library)); - return hr; - 
} - - winrt::com_ptr activation_factory; - HRESULT const hr = call(classId_hstring, activation_factory.put_void()); - - if (FAILED(hr)) - { - WINRT_VERIFY(FreeLibrary(library)); - return hr; - } - - if (winrt::guid(iid) != winrt::guid_of()) - { - return activation_factory->QueryInterface(iid, factory); - } - - *factory = activation_factory.detach(); - return S_OK; -} - // Forward declare CreateFeatureDescriptor static winml::ILearningModelFeatureDescriptor CreateFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata); static TensorKind -TensorKindFromOnnxDataType( - ONNX_NAMESPACE::TensorProto_DataType dataType) { - using TensorType = ONNX_NAMESPACE::TensorProto_DataType; +TensorKindFromONNXTensorElementDataType(ONNXTensorElementDataType dataType) { switch (dataType) { - case TensorType::TensorProto_DataType_BOOL: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { return TensorKind::Boolean; } - case TensorType::TensorProto_DataType_STRING: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { return TensorKind::String; } - case TensorType::TensorProto_DataType_FLOAT16: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { return TensorKind::Float16; } - case TensorType::TensorProto_DataType_FLOAT: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { return TensorKind::Float; } - case TensorType::TensorProto_DataType_DOUBLE: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { return TensorKind::Double; } - case TensorType::TensorProto_DataType_INT8: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { return TensorKind::Int8; } - case TensorType::TensorProto_DataType_INT16: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { return TensorKind::Int16; } - case 
TensorType::TensorProto_DataType_INT32: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { return TensorKind::Int32; } - case TensorType::TensorProto_DataType_INT64: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { return TensorKind::Int64; } - case TensorType::TensorProto_DataType_UINT8: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { return TensorKind::UInt8; } - case TensorType::TensorProto_DataType_UINT16: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { return TensorKind::UInt16; } - case TensorType::TensorProto_DataType_UINT32: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { return TensorKind::UInt32; } - case TensorType::TensorProto_DataType_UINT64: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { return TensorKind::UInt64; } - case TensorType::TensorProto_DataType_COMPLEX64: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: { return TensorKind::Complex64; } - case TensorType::TensorProto_DataType_COMPLEX128: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: { return TensorKind::Complex128; } - default: { return TensorKind::Undefined; } + default: { + return TensorKind::Undefined; + } } } @@ -240,26 +156,10 @@ TensorKindToString(TensorKind tensorKind) { return "complex128"; } case TensorKind::Undefined: - default: { return "undefined"; } - } -} - -static std::vector -ConvertShapeProtoToVector( - const ::onnx::TensorShapeProto& shape_proto) { - std::vector shape; - for (int i = 0; i < shape_proto.dim_size(); i++) { - auto& dim = shape_proto.dim(i); - if (dim.has_dim_param()) { - shape.push_back(-1); - } else if (dim.has_dim_value()) { - shape.push_back(dim.dim_value()); - } else { - winrt::throw_hresult(E_INVALIDARG); + default: { + return "undefined"; } } - - return shape; } static const char* @@ -410,16 +310,16 @@ enum class TensorType { 
Tensor_Data, static TensorType GetTensorType( - const ::onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + OrtTypeInfo* type_info, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); + const char* denotation; + size_t len; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len), + engine_factory->UseOrtApi()); - THROW_HR_IF_MSG( - E_FAIL, - type_proto.has_tensor_type() == false, - "Malformed onnx file."); - - auto has_image_denotation = type_proto.denotation() == "IMAGE"; + constexpr char c_image[] = "IMAGE"; + auto has_image_denotation = strncmp(denotation, c_image, _countof(c_image)) == 0; if (!has_image_denotation) { return TensorType::Tensor_Data; } @@ -430,9 +330,15 @@ GetTensorType( // Check if the tensor value_info_proto is of type float. // IMAGE tensors MUST be of type float - const auto& tensor_type = type_proto.tensor_type(); - auto tensor_kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(tensor_type.elem_type())); + const OrtTensorTypeAndShapeInfo* tensor_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType tensor_element_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), + engine_factory->UseOrtApi()); + + auto tensor_kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); auto is_float_tensor = tensor_kind == TensorKind::Float; if (!is_float_tensor) { log_stream << "Unsupported image with " << TensorKindToString(tensor_kind) @@ -471,7 +377,7 @@ GetTensorType( has_unsupported_image_metadata); if (is_tensor_improperly_annotated_as_image) { - TraceLoggingWrite(winmla::winml_trace_logging_provider, + TraceLoggingWrite(winml_trace_logging_provider, "WinMLInputValidation", 
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TraceLoggingLevel(WINEVENT_LEVEL_WARNING), @@ -491,21 +397,35 @@ GetTensorType( static winml::ILearningModelFeatureDescriptor CreateTensorFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata, bool has_unsupported_image_metadata) { - const auto& type_proto = value_info_proto->type(); - const auto& tensor_type = type_proto.tensor_type(); - auto shape = WinML::ConvertShapeProtoToVector(tensor_type.shape()); - auto kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(tensor_type.elem_type())); - - TensorFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), // description - value_info_proto->name().empty() == false, // is_required + auto type_info = feature_descriptor->type_info_.get(); + + const OrtTensorTypeAndShapeInfo* tensor_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info), + engine_factory->UseOrtApi()); + size_t num_dims; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims), + engine_factory->UseOrtApi()); + + auto shape = std::vector(num_dims); + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType tensor_element_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), + engine_factory->UseOrtApi()); + + auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); + + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, // description kind, shape, + feature_descriptor->name_length_ > 0, // is_required 
has_unsupported_image_metadata); return descriptor.as(); @@ -513,13 +433,27 @@ CreateTensorFeatureDescriptor( static winml::ILearningModelFeatureDescriptor CreateImageFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); - const auto& tensor_type = type_proto.tensor_type(); - auto shape = WinML::ConvertShapeProtoToVector(tensor_type.shape()); - auto kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(tensor_type.elem_type())); + auto type_info = feature_descriptor->type_info_.get(); + + const OrtTensorTypeAndShapeInfo* tensor_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info), + engine_factory->UseOrtApi()); + + size_t num_dims; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims), + engine_factory->UseOrtApi()); + + auto shape = std::vector(num_dims); + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType tensor_element_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), + engine_factory->UseOrtApi()); + auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); // pixel format and alpha auto pixel_format_value = FetchMetadataValueOrNull(metadata, c_bitmap_pixel_format_key); @@ -527,18 +461,13 @@ CreateImageFeatureDescriptor( auto pixel_format = format_info.first; auto alpha_mode = format_info.second; - // paulm: commenting this out during layering. gamma and nominal are never used - // since we only support one of them. 
if a non support one is set, they all fall back - // to TensorFeatureDescriptor (invalid image metadata) -#ifdef DONE_LAYERING // color space gamma value - auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key); - auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value); + auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key); + auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value); // nominal range - auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key); - auto nominal_range = CreateImageNominalPixelRange(nominal_range_value); -#endif + auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key); + auto nominal_range = CreateImageNominalPixelRange(nominal_range_value); // The current code assumes that the shape will be in NCHW. // Should the model metadata be read instead??? @@ -546,42 +475,59 @@ CreateImageFeatureDescriptor( const int c_width_dimension = 3; auto height = static_cast(shape[c_height_dimension]); auto width = static_cast(shape[c_width_dimension]); - ImageFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), - value_info_proto->name().empty() == false, // is_required + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, kind, shape, + feature_descriptor->name_length_ > 0, // is_required pixel_format, alpha_mode, width, - height); + height, + nominal_range, + color_space_gamma); return descriptor.as(); } static winml::ILearningModelFeatureDescriptor CreateMapFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); - auto type_proto_map = 
type_proto.map_type(); + auto type_info = feature_descriptor->type_info_.get(); + + const OrtMapTypeInfo* map_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType map_key_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type), + engine_factory->UseOrtApi()); - auto key_kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(type_proto_map.key_type())); + auto key_kind = WinML::TensorKindFromONNXTensorElementDataType(map_key_data_type); - onnx::ValueInfoProto dummy_value_info_proto; - dummy_value_info_proto.set_name(value_info_proto->name().c_str()); - dummy_value_info_proto.set_doc_string(value_info_proto->doc_string().c_str()); - *dummy_value_info_proto.mutable_type() = type_proto_map.value_type(); + OrtTypeInfo* map_value_type_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info), + engine_factory->UseOrtApi()); + + UniqueOrtTypeInfo unique_map_value_type_info(map_value_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo); + + OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper; + dummy_ort_value_info_wrapper.description_ = feature_descriptor->description_; + dummy_ort_value_info_wrapper.description_length_ = feature_descriptor->description_length_; + dummy_ort_value_info_wrapper.name_ = feature_descriptor->name_; + dummy_ort_value_info_wrapper.name_length_ = feature_descriptor->name_length_; + dummy_ort_value_info_wrapper.type_info_ = std::move(unique_map_value_type_info); auto value_descriptor = - CreateFeatureDescriptor(&dummy_value_info_proto, metadata); + CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata); - MapFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), 
- value_info_proto->name().empty() == false, // is_rRequired + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, + feature_descriptor->name_length_ > 0, // is_required key_kind, value_descriptor); return descriptor.as(); @@ -589,24 +535,35 @@ CreateMapFeatureDescriptor( static winml::ILearningModelFeatureDescriptor CreateSequenceFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); - // assert(typeProto->has_sequence_type()); - auto type_proto_sequence = type_proto.sequence_type(); + auto type_info = feature_descriptor->type_info_.get(); + + const OrtSequenceTypeInfo* sequence_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info), + engine_factory->UseOrtApi()); + + OrtTypeInfo* sequence_element_type_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info), + engine_factory->UseOrtApi()); - onnx::ValueInfoProto dummy_value_info_proto; - dummy_value_info_proto.set_name(value_info_proto->name().c_str()); - dummy_value_info_proto.set_doc_string(value_info_proto->doc_string().c_str()); - *dummy_value_info_proto.mutable_type() = type_proto_sequence.elem_type(); + UniqueOrtTypeInfo unique_sequence_element_type_info(sequence_element_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo); + + OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper; + dummy_ort_value_info_wrapper.description_ = feature_descriptor->description_; + dummy_ort_value_info_wrapper.description_length_ = feature_descriptor->description_length_; + dummy_ort_value_info_wrapper.name_ = feature_descriptor->name_; + dummy_ort_value_info_wrapper.name_length_ = feature_descriptor->name_length_; + 
dummy_ort_value_info_wrapper.type_info_ = std::move(unique_sequence_element_type_info); auto element_descriptor = - CreateFeatureDescriptor(&dummy_value_info_proto, metadata); + CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata); - SequenceFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), - value_info_proto->name().empty() == false, // is_required + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, + feature_descriptor->name_length_ > 0, // is_required element_descriptor); return descriptor.as(); @@ -614,36 +571,43 @@ CreateSequenceFeatureDescriptor( static winml::ILearningModelFeatureDescriptor CreateFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); + auto type_info = feature_descriptor->type_info_.get(); + + ONNXType onnx_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type), + engine_factory->UseOrtApi()); - using ValueCase = ::onnx::TypeProto::ValueCase; - switch (type_proto.value_case()) { - case ValueCase::kTensorType: { - auto tensor_type = - GetTensorType(value_info_proto, metadata); + switch (onnx_type) { + case ONNXType::ONNX_TYPE_TENSOR: { + auto tensor_type = GetTensorType(engine_factory, type_info, metadata); if (tensor_type == TensorType::Tensor_Image) { return CreateImageFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata); } else { auto has_unsupported_image_metadata = tensor_type == TensorType::Tensor_Data_UnsupportedImageMetadata; return CreateTensorFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata, has_unsupported_image_metadata); } } - 
case ValueCase::kMapType: { + case ONNXType::ONNX_TYPE_MAP: { return CreateMapFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata); } - case ValueCase::kSequenceType: { + case ONNXType::ONNX_TYPE_SEQUENCE: { return CreateSequenceFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata); } default: @@ -651,18 +615,17 @@ CreateFeatureDescriptor( } } -FeatureDescriptorFactory::FeatureDescriptorFactory( - const std::unordered_map& metadata) : metadata_(metadata) {} +OnnxruntimeDescriptorConverter::OnnxruntimeDescriptorConverter( + OnnxruntimeEngineFactory* engine_factory, + const std::unordered_map& metadata) : engine_factory_(engine_factory), metadata_(metadata) {} wfc::IVector -FeatureDescriptorFactory::CreateDescriptorsFromValueInfoProtos( - const std::vector& value_info_protos) { - auto features = - winrt::single_threaded_vector(); - - for (auto value_info_proto : value_info_protos) { - auto descriptor = WinML::CreateFeatureDescriptor(value_info_proto, metadata_); - features.Append(descriptor); +OnnxruntimeDescriptorConverter::ConvertToLearningModelDescriptors(const std::vector& descriptors) { + auto features = winrt::single_threaded_vector(); + + for (const auto& descriptor : descriptors) { + auto learning_model_descriptor = WinML::CreateFeatureDescriptor(engine_factory_.Get(), &descriptor, metadata_); + features.Append(learning_model_descriptor); } return features; diff --git a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h new file mode 100644 index 0000000000000..4fab8bc443e7a --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+#pragma once + +#include "pch.h" + +namespace Windows::AI::MachineLearning { + +struct OnnxruntimeValueInfoWrapper { + OnnxruntimeValueInfoWrapper() : type_info_(UniqueOrtTypeInfo(nullptr, nullptr)) {} + const char* name_ = nullptr; + size_t name_length_ = 0; + const char* description_ = nullptr; + size_t description_length_ = 0; + UniqueOrtTypeInfo type_info_; +}; + +class OnnxruntimeEngineFactory; + +struct OnnxruntimeDescriptorConverter { + OnnxruntimeDescriptorConverter( + OnnxruntimeEngineFactory* engine_factory, + const std::unordered_map& model_metadata); + + wfc::IVector + ConvertToLearningModelDescriptors(const std::vector& descriptors); + + private: + Microsoft::WRL::ComPtr engine_factory_; + const std::unordered_map& metadata_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp new file mode 100644 index 0000000000000..8f23e6864f73e --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "pch.h" + +#ifdef USE_DML + +#include "OnnxruntimeDmlSessionBuilder.h" +#include "OnnxruntimeEngine.h" +#include "OnnxruntimeErrors.h" +#include "LearningModelDevice.h" + +using namespace Windows::AI::MachineLearning; + +HRESULT OnnxruntimeDmlSessionBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, ID3D12Device* device, ID3D12CommandQueue* queue) { + engine_factory_ = engine_factory; + device_.copy_from(device); + queue_.copy_from(queue); + return S_OK; +} + +HRESULT +OnnxruntimeDmlSessionBuilder::CreateSessionOptions( + OrtSessionOptions** options) { + RETURN_HR_IF_NULL(E_POINTER, options); + + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtSessionOptions* ort_options; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&ort_options), + ort_api); + + auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions); + + // set the graph optimization level to all (used to be called level 3) + RETURN_HR_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL), + ort_api); + + // Disable the mem pattern session option for DML. It will cause problems with how memory is allocated. 
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()), + ort_api); + + // Request the dml ep + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get()), + ort_api); + +#ifndef _WIN64 + auto use_arena = false; +#else + auto use_arena = true; +#endif + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena), + ort_api); + + // call release() so the underlying OrtSessionOptions object isn't freed + *options = session_options.release(); + + return S_OK; +} + +HRESULT OnnxruntimeDmlSessionBuilder::CreateSession( + OrtSessionOptions* options, + OrtSession** session) { + RETURN_HR_IF_NULL(E_POINTER, session); + + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtEnv* ort_env; + RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env)); + + OrtSession* ort_session_raw; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw), + engine_factory_->UseOrtApi()); + auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession); + + *session = ort_session.release(); + + return S_OK; +} + +HRESULT OnnxruntimeDmlSessionBuilder::Initialize( + OrtSession* session) { + RETURN_HR_IF_NULL(E_INVALIDARG, session); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session), + engine_factory_->UseOrtApi()); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true), + engine_factory_->UseOrtApi()); + + // Flush the D3D12 work from the DML execution provider + 
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +#endif USE_DML \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h new file mode 100644 index 0000000000000..0f651e823a532 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "OnnxruntimeSessionBuilder.h" + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineFactory; + +class OnnxruntimeDmlSessionBuilder : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IOrtSessionBuilder> { + public: + HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, ID3D12Device* device, ID3D12CommandQueue* queue); + + HRESULT STDMETHODCALLTYPE CreateSessionOptions( + OrtSessionOptions** options) override; + + HRESULT STDMETHODCALLTYPE CreateSession( + OrtSessionOptions* options, + OrtSession** session) override; + + HRESULT STDMETHODCALLTYPE Initialize( + OrtSession* session) override; + + private: + Microsoft::WRL::ComPtr engine_factory_; + winrt::com_ptr device_; + winrt::com_ptr queue_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp new file mode 100644 index 0000000000000..05b114bdaf6d7 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -0,0 +1,1265 @@ +#include "pch.h" + +#include "OnnxruntimeEngine.h" + +#include "PheonixSingleton.h" +#include "OnnxruntimeEnvironment.h" +#include "OnnxruntimeEngineBuilder.h" +#include "OnnxruntimeModel.h" +#include "OnnxruntimeSessionBuilder.h" +#include "OnnxruntimeErrors.h" + +using namespace WinML; + +static const OrtApi* GetVersionedOrtApi() { + static const uint32_t ort_version = 
1; + const auto ort_api_base = OrtGetApiBase(); + return ort_api_base->GetApi(ort_version); +} + +static const WinmlAdapterApi* GetVersionedWinmlAdapterApi() { + return OrtGetWinMLAdapter(GetVersionedOrtApi()); +} + +static ONNXTensorElementDataType +ONNXTensorElementDataTypeFromTensorKind(winml::TensorKind kind) { + switch (kind) { + case winml::TensorKind::Boolean: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; + } + case winml::TensorKind::String: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; + } + case winml::TensorKind::Float16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + } + case winml::TensorKind::Float: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + case winml::TensorKind::Double: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + } + case winml::TensorKind::Int8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + } + case winml::TensorKind::Int16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; + } + case winml::TensorKind::Int32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + } + case winml::TensorKind::Int64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } + case winml::TensorKind::UInt8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + } + case winml::TensorKind::UInt16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; + } + case winml::TensorKind::UInt32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; + } + case winml::TensorKind::UInt64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; + } + case winml::TensorKind::Complex64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; + } + case winml::TensorKind::Complex128: { + return 
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; + } + default: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + } + } +} + +OnnxruntimeValue::OnnxruntimeValue() : value_(nullptr, nullptr), allocator_(nullptr, nullptr) {} + +OnnxruntimeValue::~OnnxruntimeValue() { + value_.reset(nullptr); + allocator_.reset(nullptr); +} + +HRESULT OnnxruntimeValue::RuntimeClassInitialize(OnnxruntimeEngine* engine, UniqueOrtValue&& ort_value, UniqueOrtAllocator&& allocator) { + engine_ = engine; + value_ = std::move(ort_value); + allocator_ = std::move(allocator); + + return S_OK; +} + +HRESULT OnnxruntimeValue::IsEmpty(bool* out) { + *out = UseOrtValue() == nullptr; + return S_OK; +} + +HRESULT OnnxruntimeValue::IsCpu(bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); + + OrtMemoryInfo* ort_memory_info; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(value_.get(), &ort_memory_info), + ort_api); + auto memory_info = UniqueOrtMemoryInfo(ort_memory_info, ort_api->ReleaseMemoryInfo); + + const char* name; + RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(memory_info.get(), &name), + ort_api); + + OrtMemType type; + RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(memory_info.get(), &type), + ort_api); + + *out = !strcmp(name, "Cpu") || + type == OrtMemType::OrtMemTypeCPUOutput || + type == OrtMemType::OrtMemTypeCPUInput; + return S_OK; +} + +static int64_t ShapeSize(const int64_t* shape, size_t count) { + // for each dim + int64_t size = 1; + for (int i = 0; i < count; i++) { + // find out it's total size + size *= shape[i]; + // make sure there are no invalid dimensions (-1 or any invalid shape) + THROW_HR_IF(E_INVALIDARG, shape[i] <= 0); + } + return size; +} + +static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value, + OrtTensorTypeAndShapeInfo* type_and_shape_info) { + std::vector out; + + 
size_t size; + THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info, &size), + ort_api); + + std::vector shape(size); + THROW_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info, &shape[0], size), + ort_api); + + auto length = ShapeSize(shape.data(), shape.size()); + + // make a big buffer to hold all the string data + size_t buffer_length; + THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorDataLength(ort_value, &buffer_length), + ort_api); + + std::vector strings; + std::unique_ptr buffer(new uint8_t[buffer_length]); + std::vector offsets(length); + + THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorContent(ort_value, buffer.get(), buffer_length, offsets.data(), offsets.size()), + ort_api); + + // now go build all the strings + for (auto i = 0; i < length; ++i) { + size_t str_len = 0; + // are we on the last one? + if (i == (length - 1)) { + str_len = buffer_length - offsets[i]; + } else { + str_len = offsets[i + 1] - offsets[i]; + } + strings.push_back(std::string_view(reinterpret_cast(buffer.get() + offsets[i]), str_len)); + } + + return std::make_shared>(std::move(strings), std::move(buffer)); +} + +HRESULT OnnxruntimeValue::GetResource(WinML::Resource& out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); + + void* mutable_data = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(value_.get(), &mutable_data), + ort_api); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(engine_->UseOrtSession(), 0, &ort_provider), + ort_api); + + bool is_cpu = false; + if (SUCCEEDED(IsCpu(&is_cpu)) && !is_cpu) { + void* resource; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data, + reinterpret_cast(&resource)), + ort_api); + out = WinML::Resource(resource, [](void*) { /*do nothing, as this pointer is actually a com pointer! 
*/ }); + } else { + int is_tensor; + RETURN_HR_IF_NOT_OK_MSG(ort_api->IsTensor(value_.get(), &is_tensor), + ort_api); + if (is_tensor == 0) { + out = WinML::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! */ }); + return S_OK; + } + + OrtTensorTypeAndShapeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info), + ort_api); + auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo); + + ONNXTensorElementDataType data_type; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(type_and_shape_info.get(), &data_type), + ort_api); + + if (data_type == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + auto strings = GetStrings(ort_api, value_.get(), info); + auto string_data = strings->first.data(); + out = WinML::Resource(string_data, [capture_strings = strings](void*) { /*This deleter does nothing but capture the strings, which extends the lifetime of the returned strings.*/ }); + } else { + out = WinML::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! 
*/ }); + } + } + return S_OK; +} + +HRESULT OnnxruntimeValue::IsTensor(bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + + ONNXType type = ONNXType::ONNX_TYPE_UNKNOWN; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueType(value_.get(), &type), + ort_api); + *out = type == ONNXType::ONNX_TYPE_TENSOR; + return S_OK; +} + +HRESULT OnnxruntimeValue::IsOfTensorType(winml::TensorKind kind, bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + OrtTensorTypeAndShapeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info), + ort_api); + auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo); + + ONNXTensorElementDataType data_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(type_and_shape_info.get(), &data_type), + ort_api); + + *out = data_type == ONNXTensorElementDataTypeFromTensorKind(kind); + return S_OK; +} + +HRESULT OnnxruntimeValue::GetTensorShape(std::vector& shape_vector) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + OrtTensorTypeAndShapeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info), + ort_api); + auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo); + + size_t size; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info.get(), &size), + ort_api); + + std::vector shape(size); + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info.get(), &shape[0], size), + ort_api); + + shape_vector = std::move(shape); + return S_OK; +} + +static bool EnsureMapTypeInfo(OnnxruntimeEngine* engine, OrtTypeInfo* type_info, winml::TensorKind key_kind, winml::TensorKind value_kind) { + auto ort_api = engine->GetEngineFactory()->UseOrtApi(); + auto winml_adapter_api = engine->GetEngineFactory()->UseWinmlAdapterApi(); + + 
const OrtMapTypeInfo* map_info; + THROW_IF_NOT_OK_MSG(winml_adapter_api->CastTypeInfoToMapTypeInfo(type_info, &map_info), + ort_api); + + ONNXTensorElementDataType map_key_type; + THROW_IF_NOT_OK_MSG(winml_adapter_api->GetMapKeyType(map_info, &map_key_type), + ort_api); + + if (map_key_type == ONNXTensorElementDataTypeFromTensorKind(key_kind)) { + OrtTypeInfo* value_info; + THROW_IF_NOT_OK_MSG(winml_adapter_api->GetMapValueType(map_info, &value_info), + ort_api); + auto map_value_info = UniqueOrtTypeInfo(value_info, ort_api->ReleaseTypeInfo); + + const OrtTensorTypeAndShapeInfo* value_tensor_info = nullptr; + THROW_IF_NOT_OK_MSG(ort_api->CastTypeInfoToTensorInfo(map_value_info.get(), &value_tensor_info), + ort_api); + + if (value_tensor_info) { + ONNXTensorElementDataType map_value_tensor_type; + THROW_IF_NOT_OK_MSG(ort_api->GetTensorElementType(value_tensor_info, &map_value_tensor_type), + ort_api); + + if (map_value_tensor_type == ONNXTensorElementDataTypeFromTensorKind(value_kind)) { + size_t num_dims; + THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(value_tensor_info, &num_dims), + ort_api); + + return num_dims == 0; + } + } + } + return false; +} + +HRESULT OnnxruntimeValue::IsOfMapType(winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + + OrtTypeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info), + ort_api); + auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo); + + ONNXType type; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type), + ort_api); + + if (type == ONNXType::ONNX_TYPE_MAP) { + *out = EnsureMapTypeInfo(engine_.Get(), unique_type_info.get(), key_kind, value_kind); + } + + *out = false; + + return S_OK; +} + +HRESULT OnnxruntimeValue::IsOfVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) { + auto ort_api = 
engine_->GetEngineFactory()->UseOrtApi(); + auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); + + OrtTypeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info), + ort_api); + auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo); + + ONNXType type; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type), + ort_api); + + if (type == ONNXType::ONNX_TYPE_SEQUENCE) { + const OrtSequenceTypeInfo* sequence_info; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CastTypeInfoToSequenceTypeInfo(unique_type_info.get(), &sequence_info), + ort_api); + + OrtTypeInfo* element_info; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetSequenceElementType(sequence_info, &element_info), + ort_api); + auto unique_element_info = UniqueOrtTypeInfo(element_info, ort_api->ReleaseTypeInfo); + + *out = EnsureMapTypeInfo(engine_.Get(), unique_element_info.get(), key_kind, value_kind); + } + return S_OK; +} + +HRESULT OnnxruntimeValue::SetParameter(IUnknown* param) { + param_ = param; + return S_OK; +} + +OrtValue* OnnxruntimeValue::UseOrtValue() { + return value_.get(); +} + +HRESULT OnnxruntimeValue::AssignOrtValue(OrtValue* in) { + value_.reset(in); + return S_OK; +} + +OnnxruntimeEngine::OnnxruntimeEngine() : session_(nullptr, nullptr) { +} + +HRESULT OnnxruntimeEngine::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, + UniqueOrtSession&& session, + IOrtSessionBuilder* session_builder) { + engine_factory_ = engine_factory; + session_ = std::move(session); + session_builder_ = session_builder; + return S_OK; +} + +HRESULT OnnxruntimeEngine::LoadModel(_In_ IModel* model) { + Microsoft::WRL::ComPtr onnxruntime_model; + RETURN_IF_FAILED(model->QueryInterface(IID_PPV_ARGS(&onnxruntime_model))); + + OrtModel* ort_model; + RETURN_IF_FAILED(onnxruntime_model->DetachOrtModel(&ort_model)); + + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + 
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionLoadAndPurloinModel(session_.get(), ort_model), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::Initialize() { + RETURN_IF_FAILED(session_builder_->Initialize(session_.get())); + return S_OK; +} + +HRESULT OnnxruntimeEngine::RegisterGraphTransformers() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterGraphTransformers(session_.get()), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::RegisterCustomRegistry(IMLOperatorRegistry* registry) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterCustomRegistry(session_.get(), registry), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::EndProfiling() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionEndProfiling(session_.get()), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::StartProfiling() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtEnv* ort_env; + engine_factory_->GetOrtEnvironment(&ort_env); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionStartProfiling(ort_env, session_.get()), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::FlushContext() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::TrimUploadHeap() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + 
OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderTrimUploadHeap(ort_provider), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(ort_provider), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +HRESULT OnnxruntimeEngine::CopyValueAcrossDevices(IValue* src, IValue* dest) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + auto src_value = static_cast(src); + auto dest_value = static_cast(dest); + + bool is_empty; + auto has_null_source = (SUCCEEDED(src_value->IsEmpty(&is_empty)) && is_empty); + RETURN_HR_IF(E_FAIL, has_null_source); + + auto has_null_dest = (SUCCEEDED(dest_value->IsEmpty(&is_empty)) && is_empty); + RETURN_HR_IF(E_FAIL, has_null_dest); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCopyTensor(ort_provider, src_value->UseOrtValue(), dest_value->UseOrtValue()), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +HRESULT OnnxruntimeEngine::Sync() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + 
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ExecutionProviderSync(ort_provider), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +OrtSession* OnnxruntimeEngine::UseOrtSession() { + return session_.get(); +} + +const OrtApi* OnnxruntimeEngine::UseOrtApi() { + return engine_factory_->UseOrtApi(); +} + +OnnxruntimeEngineFactory* OnnxruntimeEngine::GetEngineFactory() { + return engine_factory_.Get(); +} + +/* +* OnnxruntimeEngine::CreateTensorValue +* +* Used by callers like ImageFeatureValue to allocate a cpu or gpu OrtValue with ORT owned memory. +* In the image feature value case, tensorization creates temporary buffers, and will need to copy the value from +* its source location to the ort value. Since a copy is required, there is need to preserve the caller's memory locations. +* We simply allocate memory with ORT and copy the tensorized values into it. +*/ +HRESULT OnnxruntimeEngine::CreateTensorValue(const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + OrtAllocator* ort_allocator; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator), + engine_factory_->UseOrtApi()); + + auto unique_allocator = UniqueOrtAllocator(ort_allocator, winml_adapter_api->FreeProviderAllocator); // the release here should probably not return anything + + OrtValue* ort_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value), + ort_api); + auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), 
std::move(unique_allocator))); + return S_OK; +} + +using DmlAllocatorResource = std::unique_ptr; +class DmlAllocatorWrapper : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IUnknown> { + public: + DmlAllocatorWrapper() : dml_resource_(nullptr, nullptr) {} + + HRESULT RuntimeClassInitialize(DmlAllocatorResource&& dml_resource) { + dml_resource_ = std::move(dml_resource); + return S_OK; + } + + private: + DmlAllocatorResource dml_resource_; +}; + +/* +* OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource +* +* Used by callers like TensorBase to allocate a gpu OrtValue based on a called owned ID3D12Resource. +* WinML cannot use ORT allocators here since they will allocate the ID3D12Resource and force a copy from the user provided value. +*/ +HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resource* d3d_resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + OrtMemoryInfo* dml_memory = nullptr; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory), + engine_factory_->UseOrtApi()); + + void* dml_allocator_resource; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource), + engine_factory_->UseOrtApi()); + + auto unique_dml_allocator_resource = + DmlAllocatorResource(dml_allocator_resource, + [](void* ptr) { + GetVersionedWinmlAdapterApi()->DmlFreeGPUAllocation(ptr); + }); + + // create the OrtValue as a tensor letting ort know that we own the data buffer + OrtValue* ort_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue( + dml_memory, + 
unique_dml_allocator_resource.get(), + d3d_resource->GetDesc().Width, + shape, + count, + ONNXTensorElementDataTypeFromTensorKind(kind), + &ort_value), + ort_api); + auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue); + + Microsoft::WRL::ComPtr out_value; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&out_value, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr))); + + // Cache the allocator on the value so it destructs appropriately when the value is dropped + Microsoft::WRL::ComPtr dml_allocator_resource_wrapper; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&dml_allocator_resource_wrapper, std::move(unique_dml_allocator_resource))); + + RETURN_IF_FAILED(out_value->SetParameter(dml_allocator_resource_wrapper.Get())); + + *out = out_value.Detach(); + + return S_OK; +} + +/* +* OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy +* +* Used by callers like TensorString to allocate a cpu OrtValue and populate the contents with use specified data. +* WinML cannot use CreateTensorWithDataAsOrtValue since externally allocated strings are not supported on the c-abi. +* The c-abi string implementation requires a copy the external buffer into its own internal std::string copy. +* In addition, strings have different APIs on the c-abi like FillStringTensor to populate the buffer, and so strings +* have a different calling pattern than other Tensor types of simple data types. 
+*/ +HRESULT OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy(const char* const* data, size_t num_elements, const int64_t* shape, size_t count, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + RETURN_IF_FAILED(CreateTensorValue(shape, count, winml::TensorKind::String, out)); + + auto ort_value = reinterpret_cast(*out)->UseOrtValue(); + RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(ort_value, reinterpret_cast(data), num_elements), + ort_api); + return S_OK; +} + +/* +* OnnxruntimeEngine::CreateTensorValueFromExternalBuffer +* +* Used by callers like TensorBase to allocate a cpu OrtValue that is backed by caller owned memory. +*/ +HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalBuffer(void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + + if (kind == winml::TensorKind::String) { + // String buffers cannot be passed into the ort api directly because ort c-api tensor strings cannot be backed by external memory + return E_NOTIMPL; + } + + // TODO: what is the difference between the device allocator and the arena allocator? + OrtMemoryInfo* cpu_memory; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory), + ort_api); + + OrtValue* ort_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue( + cpu_memory, + data, + size_in_bytes, + shape, + count, + ONNXTensorElementDataTypeFromTensorKind(kind), + &ort_value), + ort_api); + auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue); + + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; +} + +/* +* OnnxruntimeEngine::CreateNullValue +* +* Used by callers like TensorBase and the binding object to allocate a cpu OrtValue that is empty. +* This is used for WinML unbound outputs. 
+*/ +HRESULT OnnxruntimeEngine::CreateNullValue(_Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto unique_value = UniqueOrtValue(nullptr, ort_api->ReleaseValue); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; +} + +template +struct AbiTypeInfo { + using CppWinRTType = TAbiType; + using OrtType = TAbiType; + using ResourceType = TAbiType; +}; + +template <> +struct AbiTypeInfo { + using CppWinRTType = winrt::hstring; + using OrtType = const char*; + using ResourceType = std::string_view; +}; + +template +typename auto CppwinrtTypeToOrtType(TCppwinrtType raw) { + return raw; +} + +template <> +typename auto CppwinrtTypeToOrtType(winrt::hstring raw) { + return WinML::Strings::UTF8FromHString(raw); +} + +template +typename auto ResourceTypeToCppwinrtType(typename AbiTypeInfo::ResourceType value) { + return value; +} + +template <> +typename auto ResourceTypeToCppwinrtType(typename AbiTypeInfo::ResourceType value) { + return WinML::Strings::HStringFromUTF8(value.data(), value.size()); +} + +template +auto CastToWinrtMap(IInspectable* map_insp) { + using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; + using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; + + ::winrt::Windows::Foundation::IInspectable map_inspectable; + ::winrt::Windows::Foundation::Collections::IMap map; + winrt::copy_from_abi(map_inspectable, map_insp); + map_inspectable.as(map); + return map; +} + +template +auto CastToWinrtSequenceOfMaps(IInspectable* sequence_insp) { + using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; + using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; + + using cppwinrt_element_map_type = ::winrt::Windows::Foundation::Collections::IMap; + using cppwinrt_sequence_type = ::winrt::Windows::Foundation::Collections::IVector; + cppwinrt_sequence_type sequence; + ::winrt::Windows::Foundation::IInspectable 
sequence_inspectable; + winrt::copy_from_abi(sequence_inspectable, sequence_insp); + sequence_inspectable.as(sequence); + return sequence; +} + +template +struct FillMapTensors { + static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) { + AbiTypeInfo::OrtType* keys_mutable_data; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast(&keys_mutable_data)), + ort_api); + + AbiTypeInfo::OrtType* values_mutable_data; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast(&values_mutable_data)), + ort_api); + + auto map = CastToWinrtMap(map_insp); + size_t index = 0; + for (const auto& pair : map) { + keys_mutable_data[index] = CppwinrtTypeToOrtType(pair.Key()); + values_mutable_data[index] = CppwinrtTypeToOrtType(pair.Value()); + index++; + } + return S_OK; + } +}; + +template +struct FillMapTensors { + static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) { + AbiTypeInfo::OrtType* values_mutable_data; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast(&values_mutable_data)), + ort_api); + + auto map = CastToWinrtMap(map_insp); + size_t index = 0; + std::vector keys; + for (const auto& pair : map) { + keys.push_back(CppwinrtTypeToOrtType(pair.Key())); + values_mutable_data[index] = CppwinrtTypeToOrtType(pair.Value()); + index++; + } + + std::vector raw_values; + std::transform( + keys.begin(), + keys.end(), + std::back_inserter(raw_values), + [&](auto& str) { return str.c_str(); }); + + RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size()), + ort_api); + + return S_OK; + } +}; + +template +struct FillMapTensors { + static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) { + AbiTypeInfo::OrtType* keys_mutable_data; + 
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast(&keys_mutable_data)), + ort_api); + + auto map = CastToWinrtMap(map_insp); + size_t index = 0; + std::vector values; + for (const auto& pair : map) { + keys_mutable_data[index] = CppwinrtTypeToOrtType(pair.Key()); + values.push_back(CppwinrtTypeToOrtType(pair.Value())); + index++; + } + + std::vector raw_values; + std::transform( + values.begin(), + values.end(), + std::back_inserter(raw_values), + [&](auto& str) { return str.c_str(); }); + + RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_values.data(), raw_values.size()), + ort_api); + return S_OK; + } +}; + +template <> +struct FillMapTensors { + static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) { + auto map = CastToWinrtMap(map_insp); + size_t index = 0; + std::vector keys; + std::vector values; + for (const auto& pair : map) { + keys.push_back(CppwinrtTypeToOrtType(pair.Key())); + values.push_back(CppwinrtTypeToOrtType(pair.Value())); + index++; + } + + std::vector raw_keys; + std::transform( + keys.begin(), + keys.end(), + std::back_inserter(raw_keys), + [&](auto& str) { return str.c_str(); }); + + std::vector raw_values; + std::transform( + values.begin(), + values.end(), + std::back_inserter(raw_values), + [&](auto& str) { return str.c_str(); }); + + RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_keys.data(), raw_keys.size()), + ort_api); + RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(values_ort_value, raw_values.data(), raw_values.size()), + ort_api); + return S_OK; + } +}; + +template +HRESULT CreateMapValue(OnnxruntimeEngine* engine, IInspectable* map_insp, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) { + auto ort_api = engine->UseOrtApi(); + auto map = CastToWinrtMap(map_insp); + std::vector shape = {static_cast(map.Size())}; + + winrt::com_ptr key_value; + 
RETURN_IF_FAILED(engine->CreateTensorValue(shape.data(), shape.size(), key_kind, key_value.put())); + auto keys_ort_value = static_cast(key_value.get())->UseOrtValue(); + + winrt::com_ptr value_value; + RETURN_IF_FAILED(engine->CreateTensorValue(shape.data(), shape.size(), value_kind, value_value.put())); + auto values_ort_value = static_cast(value_value.get())->UseOrtValue(); + + auto hr = FillMapTensors::Run(ort_api, map_insp, keys_ort_value, values_ort_value); + RETURN_IF_FAILED(hr); + + OrtValue* inputs[2] = {keys_ort_value, values_ort_value}; + + OrtValue* map_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateValue(inputs, 2, ONNXType::ONNX_TYPE_MAP, &map_value), + ort_api); + auto unique_map_ort_value = UniqueOrtValue(map_value, ort_api->ReleaseValue); + + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, engine, std::move(unique_map_ort_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; +} + +static auto GetMapValueCreator(OnnxruntimeEngine* engine, winml::TensorKind key_kind, winml::TensorKind value_kind) { + using namespace std::placeholders; + if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Int64) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Int64, _2); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Float, _2); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Double) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Double, _2); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::String) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::String, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Int64) { + return 
std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::Int64, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::Float, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Double) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::Double, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::String) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::String, _2); + } + + THROW_HR(E_NOTIMPL); +} + +HRESULT OnnxruntimeEngine::CreateMapValue(IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) { + return GetMapValueCreator(this, key_kind, value_kind)(map, out); +} + +template +HRESULT CreateSequenceOfMapsValue(OnnxruntimeEngine* engine, IInspectable* sequence_insp, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) { + auto ort_api = engine->UseOrtApi(); + auto sequence = CastToWinrtSequenceOfMaps(sequence_insp); + + std::vector> element_values; + for (auto element : sequence) { + winrt::com_ptr element_value; + engine->CreateMapValue(reinterpret_cast(winrt::get_abi(element)), key_kind, value_kind, element_value.put()); + element_values.push_back(element_value); + } + + std::vector element_ort_values; + std::transform(element_values.begin(), + element_values.end(), + std::back_inserter(element_ort_values), + [](auto value) { return static_cast(value.get())->UseOrtValue(); }); + + OrtValue* sequence_value; + RETURN_HR_IF_NOT_OK_MSG( + ort_api->CreateValue(element_ort_values.data(), element_ort_values.size(), + ONNXType::ONNX_TYPE_SEQUENCE, &sequence_value), + ort_api); + auto unique_sequence_ort_value = UniqueOrtValue(sequence_value, ort_api->ReleaseValue); + + 
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, engine, std::move(unique_sequence_ort_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; +} + +static auto GetSequenceOfMapsValueCreator(OnnxruntimeEngine* engine, winml::TensorKind key_kind, winml::TensorKind value_kind) { + using namespace std::placeholders; + if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) { + return std::bind(&CreateSequenceOfMapsValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Int64, _2); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) { + return std::bind(&CreateSequenceOfMapsValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Float, _2); + } + + THROW_HR(E_NOTIMPL); +} + +HRESULT OnnxruntimeEngine::CreateSequenceOfMapsValue(IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) { + RETURN_IF_FAILED(GetSequenceOfMapsValueCreator(this, key_kind, value_kind)(sequence, out)); + return S_OK; +} + +template +static HRESULT FillAbiSequence(IInspectable* sequence_insp, std::vector<::winrt::Windows::Foundation::IInspectable>& elements) { + using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; + using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; + auto sequence = CastToWinrtSequenceOfMaps(sequence_insp); + for (auto element : elements) { + ::winrt::Windows::Foundation::Collections::IMap map_element; + element.as(map_element); + sequence.Append(map_element); + } + return S_OK; +} + +static auto GetAbiSequenceFiller(winml::TensorKind key_kind, winml::TensorKind value_kind) { + using namespace std::placeholders; + if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) { + return &FillAbiSequence; + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) { + return &FillAbiSequence; + } + THROW_HR(E_NOTIMPL); +} + +static 
winrt::Windows::Foundation::IInspectable CreateMap(winml::TensorKind key_kind, winml::TensorKind value_kind) { + winrt::Windows::Foundation::IInspectable map_insp; + if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) { + auto map = winrt::single_threaded_map(); + map.as(map_insp); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) { + auto map = winrt::single_threaded_map(); + map.as(map_insp); + } + + return map_insp; +} + +HRESULT OnnxruntimeEngine::FillSequenceOfMapsValue(IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* sequence_value) { + auto ort_api = engine_factory_->UseOrtApi(); + auto onnxruntime_squence_value = static_cast(sequence_value); + auto ort_sequence_value = onnxruntime_squence_value->UseOrtValue(); + + OrtAllocator* ort_allocator; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&ort_allocator), ort_api); // This should not be freed as this owned by ort + + size_t num_elements; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueCount(ort_sequence_value, &num_elements), ort_api); + + // get the elements + std::vector<::winrt::Windows::Foundation::IInspectable> element_map_inspectables; + for (int index = 0; index < num_elements; index++) { + OrtValue* elements_ort_value = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_sequence_value, index, ort_allocator, &elements_ort_value), ort_api); + auto unique_element_value = UniqueOrtValue(elements_ort_value, ort_api->ReleaseValue); + + winrt::com_ptr element_value; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(element_value.put(), this, std::move(unique_element_value), UniqueOrtAllocator(nullptr, nullptr))); + + ::winrt::Windows::Foundation::IInspectable map_inspectable = CreateMap(key_kind, value_kind); + RETURN_IF_FAILED(FillFromMapValue(reinterpret_cast(winrt::get_abi(map_inspectable)), key_kind, value_kind, element_value.get())); + 
element_map_inspectables.push_back(map_inspectable); + } + + GetAbiSequenceFiller(key_kind, value_kind)(sequence, element_map_inspectables); + return S_OK; +} + +HRESULT OnnxruntimeEngine::CreateOneInputAcrossDevices(const char* name, IValue* src, IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + auto src_value = static_cast(src); + + bool is_set; + auto is_empty = SUCCEEDED(src_value->IsEmpty(&is_set)) && is_set; + auto is_tensor = SUCCEEDED(src_value->IsTensor(&is_set)) && is_set; + + if (is_tensor && !is_empty) { + int16_t source_location; + int16_t input_required_location; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(src_value->UseOrtValue(), &source_location), + ort_api); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetInputRequiredDeviceId(session_.get(), name, &input_required_location), + ort_api); + + if (source_location != input_required_location) { + OrtValue* dest_ort_value = nullptr; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionCopyOneInputAcrossDevices(session_.get(), name, + src_value->UseOrtValue(), &dest_ort_value), + ort_api); + auto unique_dest_ort_value = UniqueOrtValue(dest_ort_value, ort_api->ReleaseValue); + + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_dest_ort_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; + } + } + + *out = src; + (*out)->AddRef(); + return S_OK; +} + +HRESULT OnnxruntimeEngine::Run(const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) { + auto ort_api = engine_factory_->UseOrtApi(); + + OrtRunOptions* run_options; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateRunOptions(&run_options), + ort_api); + auto unique_run_options = UniqueOrtRunOptions(run_options, ort_api->ReleaseRunOptions); + + std::vector input_ort_values; + std::transform( + inputs, + inputs + num_inputs, + 
std::back_inserter(input_ort_values), + [&](auto& input) { + auto input_value = static_cast(input); + return input_value->UseOrtValue(); + }); + + std::vector output_ort_values; + std::transform( + outputs, + outputs + num_outputs, + std::back_inserter(output_ort_values), + [&](auto& output) { + auto output_value = static_cast(output); + return output_value->UseOrtValue(); + }); + + RETURN_HR_IF_NOT_OK_MSG(ort_api->Run(session_.get(), + unique_run_options.get(), + input_names, + input_ort_values.data(), + num_inputs, + output_names, + num_outputs, + output_ort_values.data()), + ort_api); + + for (size_t index = 0; index < num_outputs; index++) { + auto output_value = static_cast(outputs[index]); + if (output_value->UseOrtValue() != output_ort_values[index]) { + RETURN_IF_FAILED(output_value->AssignOrtValue(output_ort_values[index])); + } + } + + return S_OK; +} + +template +HRESULT FillAbiMap(IInspectable* map_insp, size_t num_elements, void* keys_data, void* values_data) { + auto map = CastToWinrtMap(map_insp); + + auto keys = reinterpret_cast::ResourceType*>(keys_data); + auto values = reinterpret_cast::ResourceType*>(values_data); + + for (auto i = 0; i < num_elements; ++i) { + map.Insert( + ResourceTypeToCppwinrtType(keys[i]), + ResourceTypeToCppwinrtType(values[i])); + } + return S_OK; +} + +static auto GetAbiMapFiller(winml::TensorKind key_kind, winml::TensorKind value_kind) { + using namespace std::placeholders; + if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Int64) { + return std::bind(&FillAbiMap, _1, _2, _3, _4); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) { + return std::bind(&FillAbiMap, _1, _2, _3, _4); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Double) { + return std::bind(&FillAbiMap, _1, _2, _3, _4); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::String) { + return std::bind(&FillAbiMap, 
_1, _2, _3, _4); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Int64) { + return std::bind(&FillAbiMap, _1, _2, _3, _4); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) { + return std::bind(&FillAbiMap, _1, _2, _3, _4); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Double) { + return std::bind(&FillAbiMap, _1, _2, _3, _4); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::String) { + return std::bind(&FillAbiMap, _1, _2, _3, _4); + } + + THROW_HR(E_NOTIMPL); +} + +HRESULT OnnxruntimeEngine::FillFromMapValue(IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* map_value) { + auto ort_api = engine_factory_->UseOrtApi(); + auto onnxruntime_map_value = static_cast(map_value); + auto ort_map_value = onnxruntime_map_value->UseOrtValue(); + + OrtAllocator* ort_allocator; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&ort_allocator), + ort_api); // This should not be freed as this owned by ort + + // get the keys + OrtValue* keys_ort_value = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 0, ort_allocator, &keys_ort_value), + ort_api); + auto unique_keys_value = UniqueOrtValue(keys_ort_value, ort_api->ReleaseValue); + winrt::com_ptr keys_value; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(keys_value.put(), this, std::move(unique_keys_value), UniqueOrtAllocator(nullptr, nullptr))); + + // get the keys + OrtValue* values_ort_value = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 1, ort_allocator, &values_ort_value), + ort_api); + auto unique_values_value = UniqueOrtValue(values_ort_value, ort_api->ReleaseValue); + winrt::com_ptr values_value; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(values_value.put(), this, std::move(unique_values_value), UniqueOrtAllocator(nullptr, nullptr))); + + std::vector 
keys_shape; + keys_value->GetTensorShape(keys_shape); + + WinML::Resource keys_data; + RETURN_IF_FAILED(keys_value->GetResource(keys_data)); + WinML::Resource values_data; + RETURN_IF_FAILED(values_value->GetResource(values_data)); + + auto num_elements = ShapeSize(keys_shape.data(), keys_shape.size()); + GetAbiMapFiller(key_kind, value_kind)(map, num_elements, keys_data.get(), values_data.get()); + + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::RuntimeClassInitialize() { + ort_api_ = GetVersionedOrtApi(); + winml_adapter_api_ = GetVersionedWinmlAdapterApi(); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::EnsureEnvironment() { + if (environment_ == nullptr) { + std::lock_guard lock(mutex_); + if (environment_ == nullptr) { + environment_ = PheonixSingleton(ort_api_); + } + } + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) { + RETURN_IF_FAILED(EnsureEnvironment()); + + OrtModel* ort_model = nullptr; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateModelFromPath(model_path, len, &ort_model), + ort_api_); + + auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(model))); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) { + RETURN_IF_FAILED(EnsureEnvironment()); + OrtModel* ort_model = nullptr; + if (auto status = winml_adapter_api_->CreateModelFromData(data, size, &ort_model)) { + return E_INVALIDARG; + } + + auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(model))); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::CreateEngineBuilder(_Outptr_ Windows::AI::MachineLearning::IEngineBuilder** out) { + RETURN_IF_FAILED(EnsureEnvironment()); + Microsoft::WRL::ComPtr 
onnxruntime_engine_builder; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine_builder, this)); + RETURN_IF_FAILED(onnxruntime_engine_builder.CopyTo(out)); + return S_OK; +} + +const OrtApi* OnnxruntimeEngineFactory::UseOrtApi() { + return ort_api_; +} + +const WinmlAdapterApi* OnnxruntimeEngineFactory::UseWinmlAdapterApi() { + return winml_adapter_api_; +} + +HRESULT OnnxruntimeEngineFactory::GetOrtEnvironment(OrtEnv** ort_env) { + RETURN_IF_FAILED(EnsureEnvironment()); + RETURN_IF_FAILED(environment_->GetOrtEnvironment(ort_env)); + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::EnableDebugOutput(bool is_enabled) { + RETURN_IF_FAILED(EnsureEnvironment()); + RETURN_IF_FAILED(environment_->EnableDebugOutput(is_enabled)); + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::CreateCustomRegistry(IMLOperatorRegistry** registry) { + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateCustomRegistry(registry), + ort_api_); + return S_OK; +} + +STDAPI CreateOnnxruntimeEngineFactory(_Out_ Windows::AI::MachineLearning::IEngineFactory** engine_factory) { + Microsoft::WRL::ComPtr onnxruntime_engine_factory; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine_factory)); + RETURN_IF_FAILED(onnxruntime_engine_factory.CopyTo(engine_factory)); + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h new file mode 100644 index 0000000000000..6cb940c3a22a9 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -0,0 +1,143 @@ +#include "iengine.h" + +#include + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineBuilder; +class OnnxruntimeEngineFactory; +class OnnxruntimeEnvironment; +class OnnxruntimeModel; +class OnnxruntimeEngine; + +struct IOrtSessionBuilder; + +class OnnxruntimeValue : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IValue> { + public: + OnnxruntimeValue(); + ~OnnxruntimeValue(); + + 
HRESULT RuntimeClassInitialize(OnnxruntimeEngine* engine, UniqueOrtValue&& value, UniqueOrtAllocator&& allocator); + + STDMETHOD(IsEmpty) + (bool* out) override; + STDMETHOD(IsCpu) + (bool* out) override; + STDMETHOD(GetResource) + (WinML::Resource& resource) override; + STDMETHOD(IsTensor) + (bool* out) override; + STDMETHOD(IsOfTensorType) + (winml::TensorKind kind, bool* out) override; + STDMETHOD(GetTensorShape) + (std::vector& shape_vector) override; + STDMETHOD(IsOfMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) override; + STDMETHOD(IsOfVectorMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) override; + + HRESULT(SetParameter) + (IUnknown* param); + OrtValue* UseOrtValue(); + HRESULT AssignOrtValue(OrtValue* ptr); + + private: + Microsoft::WRL::ComPtr engine_; + Microsoft::WRL::ComPtr param_; + UniqueOrtValue value_; + UniqueOrtAllocator allocator_; +}; + +class OnnxruntimeEngine : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IEngine> { + public: + OnnxruntimeEngine(); + HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, UniqueOrtSession&& session, IOrtSessionBuilder* session_builder); + + STDMETHOD(LoadModel) + (_In_ IModel* model) override; + STDMETHOD(Initialize) + () override; + STDMETHOD(RegisterGraphTransformers) + () override; + STDMETHOD(RegisterCustomRegistry) + (IMLOperatorRegistry* registry) override; + STDMETHOD(EndProfiling) + () override; + STDMETHOD(StartProfiling) + () override; + STDMETHOD(FlushContext) + () override; + STDMETHOD(TrimUploadHeap) + () override; + STDMETHOD(ReleaseCompletedReferences) + () override; + STDMETHOD(Sync) + () override; + STDMETHOD(CreateTensorValue) + (const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; + STDMETHOD(CreateTensorValueFromExternalD3DResource) + (ID3D12Resource* resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** 
out) override; + STDMETHOD(CreateTensorValueFromExternalBuffer) + (void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; + STDMETHOD(CreateStringTensorValueFromDataWithCopy) + (const char* const* data, size_t num_elements, const int64_t* shape, size_t count, _Out_ IValue** out) override; + STDMETHOD(CreateNullValue) + (_Out_ IValue** out) override; + STDMETHOD(CreateMapValue) + (IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) override; + STDMETHOD(CreateSequenceOfMapsValue) + (IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) override; + STDMETHOD(CreateOneInputAcrossDevices) + (const char* name, IValue* src, IValue** dest) override; + STDMETHOD(CopyValueAcrossDevices) + (IValue* src, IValue* dest) override; + STDMETHOD(Run) + (const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) override; + STDMETHOD(FillFromMapValue) + (IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* value) override; + STDMETHOD(FillSequenceOfMapsValue) + (IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* value) override; + + OrtSession* UseOrtSession(); + const OrtApi* UseOrtApi(); + OnnxruntimeEngineFactory* GetEngineFactory(); + + private: + Microsoft::WRL::ComPtr engine_factory_; + Microsoft::WRL::ComPtr session_builder_; + UniqueOrtSession session_; +}; + +class OnnxruntimeEngineFactory : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IEngineFactory> { + public: + HRESULT RuntimeClassInitialize(); + STDMETHOD(CreateModel) + (_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) override; + STDMETHOD(CreateModel) + (_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) override; + STDMETHOD(CreateEngineBuilder) + (IEngineBuilder** 
engine_builder) override; + STDMETHOD(EnableDebugOutput) + (bool is_enabled) override; + STDMETHOD(CreateCustomRegistry) + (_Out_ IMLOperatorRegistry** registry) override; + + const OrtApi* UseOrtApi(); + const WinmlAdapterApi* UseWinmlAdapterApi(); + HRESULT EnsureEnvironment(); + HRESULT GetOrtEnvironment(_Out_ OrtEnv** ort_env); + + private: + const OrtApi* ort_api_ = nullptr; + const WinmlAdapterApi* winml_adapter_api_ = nullptr; + std::shared_ptr environment_; + std::mutex mutex_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp new file mode 100644 index 0000000000000..ecfb6561657c9 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp @@ -0,0 +1,72 @@ +#include "pch.h" + +#include "OnnxruntimeEngine.h" +#include "OnnxruntimeEngineBuilder.h" +#include "OnnxruntimeCpuSessionBuilder.h" + +#ifdef USE_DML +#include "OnnxruntimeDmlSessionBuilder.h" +#endif + +#include "OnnxruntimeErrors.h" +using namespace WinML; + +HRESULT OnnxruntimeEngineBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) { + engine_factory_ = engine_factory; + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(Windows::AI::MachineLearning::IEngine** out) { + auto ort_api = engine_factory_->UseOrtApi(); + + Microsoft::WRL::ComPtr onnxruntime_session_builder; + + if (device_ == nullptr) { + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_session_builder, engine_factory_.Get())); + } else { +#ifdef USE_DML + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_session_builder, engine_factory_.Get(), device_.Get(), queue_.Get())); +#endif + } + + OrtSessionOptions* ort_options; + RETURN_IF_FAILED(onnxruntime_session_builder->CreateSessionOptions(&ort_options)); + auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions); + + if 
(batch_size_override_.has_value()) { + constexpr const char* DATA_BATCH = "DATA_BATCH"; + RETURN_HR_IF_NOT_OK_MSG(ort_api->AddFreeDimensionOverride(session_options.get(), DATA_BATCH, batch_size_override_.value()), + ort_api); + } + + OrtSession* ort_session = nullptr; + onnxruntime_session_builder->CreateSession(session_options.get(), &ort_session); + auto session = UniqueOrtSession(ort_session, ort_api->ReleaseSession); + + Microsoft::WRL::ComPtr onnxruntime_engine; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine, + engine_factory_.Get(), std::move(session), onnxruntime_session_builder.Get())); + RETURN_IF_FAILED(onnxruntime_engine.CopyTo(out)); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::GetD3D12Device(ID3D12Device** device) { + *device = device_.Get(); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::SetD3D12Resources(ID3D12Device* device, ID3D12CommandQueue* queue) { + device_ = device; + queue_ = queue; + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::GetID3D12CommandQueue(ID3D12CommandQueue** queue) { + *queue = queue_.Get(); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::SetBatchSizeOverride(uint32_t batch_size_override) { + batch_size_override_ = batch_size_override; + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h new file mode 100644 index 0000000000000..34e68ae742ba0 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h @@ -0,0 +1,33 @@ +#include "iengine.h" + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineBuilder : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IEngineBuilder> { + public: + HRESULT RuntimeClassInitialize(_In_ OnnxruntimeEngineFactory* engine); + + STDMETHOD(SetD3D12Resources) + (ID3D12Device* device, ID3D12CommandQueue* queue); + + STDMETHOD(GetD3D12Device) + (_Outptr_ ID3D12Device** device); + + 
STDMETHOD(GetID3D12CommandQueue) + (_Outptr_ ID3D12CommandQueue** queue); + + STDMETHOD(SetBatchSizeOverride) + (uint32_t batch_size_override); + + STDMETHOD(CreateEngine) + (_Outptr_ IEngine** out); + + private: + Microsoft::WRL::ComPtr engine_factory_; + Microsoft::WRL::ComPtr device_ = nullptr; + Microsoft::WRL::ComPtr queue_ = nullptr; + std::optional batch_size_override_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp new file mode 100644 index 0000000000000..fbd5003b6d007 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "pch.h" +#include "OnnxruntimeEnvironment.h" +#include "OnnxruntimeErrors.h" +#include "core/platform/windows/TraceLoggingConfig.h" +#include + +using namespace Windows::AI ::MachineLearning; + +static bool debug_output_ = false; + +static void WinmlOrtLoggingCallback(void* param, OrtLoggingLevel severity, const char* category, + const char* logger_id, const char* code_location, const char* message) { + UNREFERENCED_PARAMETER(param); + UNREFERENCED_PARAMETER(logger_id); + // ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry. 
+ switch (severity) { + case OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL: //Telemetry + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR: //Telemetry + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING: + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_WARNING), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO: + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE: + __fallthrough; 
//Default is Verbose too. + default: + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location)); + } + + if (debug_output_) { + OutputDebugStringA((std::string(message) + "\r\n").c_str()); + } +} + +static void WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) { + if (profiler_record->category_ == OrtProfilerEventCategory::NODE_EVENT) { + TraceLoggingWrite( + winml_trace_logging_provider, + "OnnxRuntimeProfiling", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(profiler_record->category_name_, "Category"), + TraceLoggingInt64(profiler_record->duration_, "Duration (us)"), + TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"), + TraceLoggingString(profiler_record->event_name_, "Event Name"), + TraceLoggingInt32(profiler_record->process_id_, "Process ID"), + TraceLoggingInt32(profiler_record->thread_id_, "Thread ID"), + TraceLoggingString(profiler_record->op_name_, "Operator Name"), + TraceLoggingString(profiler_record->execution_provider_, "Execution Provider")); + } else { + TraceLoggingWrite( + winml_trace_logging_provider, + "OnnxRuntimeProfiling", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(profiler_record->category_name_, "Category"), + TraceLoggingInt64(profiler_record->duration_, "Duration (us)"), + TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"), + TraceLoggingString(profiler_record->event_name_, "Event Name"), + TraceLoggingInt32(profiler_record->process_id_, 
"Process ID"), + TraceLoggingInt32(profiler_record->thread_id_, "Thread ID")); + } +} + +OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_(nullptr, nullptr) { + OrtEnv* ort_env = nullptr; + THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), + ort_api); + ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv); + + // Configure the environment with the winml logger + auto winml_adapter_api = OrtGetWinMLAdapter(ort_api); + THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env_.get(), + &WinmlOrtLoggingCallback, &WinmlOrtProfileEventCallback, nullptr, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), + ort_api); + + THROW_IF_NOT_OK_MSG(winml_adapter_api->OverrideSchema(), ort_api); +} + +HRESULT OnnxruntimeEnvironment::GetOrtEnvironment(_Out_ OrtEnv** ort_env) { + *ort_env = ort_env_.get(); + return S_OK; +} + +HRESULT OnnxruntimeEnvironment::EnableDebugOutput(bool is_enabled) { + debug_output_ = is_enabled; + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.h b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h new file mode 100644 index 0000000000000..c0e01f1989b99 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#pragma warning(push) +#pragma warning(disable : 4505) + +namespace Windows::AI ::MachineLearning { + +using UniqueOrtEnv = std::unique_ptr; + +class OnnxruntimeEnvironment { + public: + OnnxruntimeEnvironment(const OrtApi* ort_api); + + HRESULT GetOrtEnvironment(_Out_ OrtEnv** ert_env); + HRESULT EnableDebugOutput(bool is_enabled); + + private: + UniqueOrtEnv ort_env_; +}; + +} // namespace Windows::AI::MachineLearning + +#pragma warning(pop) \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeErrors.h b/winml/lib/Api.Ort/OnnxruntimeErrors.h new file mode 100644 index 0000000000000..3f9fd88b783d6 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeErrors.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once +#include "pch.h" +#include "core/providers/winml/winml_provider_factory.h" + +#ifdef _WIN32 +inline HRESULT OrtErrorCodeToHRESULT(OrtErrorCode status) noexcept { + switch (status) { + case OrtErrorCode::ORT_OK: + return S_OK; + case OrtErrorCode::ORT_FAIL: + return E_FAIL; + case OrtErrorCode::ORT_INVALID_ARGUMENT: + return E_INVALIDARG; + case OrtErrorCode::ORT_NO_SUCHFILE: + return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case OrtErrorCode::ORT_NO_MODEL: + return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case OrtErrorCode::ORT_ENGINE_ERROR: + return E_FAIL; + case OrtErrorCode::ORT_RUNTIME_EXCEPTION: + return E_FAIL; + case OrtErrorCode::ORT_INVALID_PROTOBUF: + return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case OrtErrorCode::ORT_MODEL_LOADED: + return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case OrtErrorCode::ORT_NOT_IMPLEMENTED: + return E_NOTIMPL; + case OrtErrorCode::ORT_INVALID_GRAPH: + return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case OrtErrorCode::ORT_EP_FAIL: + return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + default: + return E_FAIL; + } +} +#endif + +#define RETURN_HR_IF_NOT_OK_MSG(status, ort_api) \ + do { \ + auto _status = 
status; \ + if (_status) { \ + auto error_code = ort_api->GetErrorCode(_status); \ + auto error_message = ort_api->GetErrorMessage(_status); \ + HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \ + telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \ + RETURN_HR_MSG(hresult, \ + error_message); \ + } \ + } while (0) + +#define THROW_IF_NOT_OK_MSG(status, ort_api) \ + do { \ + auto _status = status; \ + if (_status) { \ + auto error_code = ort_api->GetErrorCode(_status); \ + auto error_message = ort_api->GetErrorMessage(_status); \ + HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \ + telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \ + winrt::hstring errorMessage(WinML::Strings::HStringFromUTF8(error_message)); \ + throw winrt::hresult_error(hresult, errorMessage); \ + } \ + } while (0) diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.cpp b/winml/lib/Api.Ort/OnnxruntimeModel.cpp new file mode 100644 index 0000000000000..bc782fbd17343 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeModel.cpp @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "pch.h" +#include "OnnxruntimeModel.h" +#include "core/platform/windows/TraceLoggingConfig.h" +#include + +#include "OnnxruntimeDescriptorConverter.h" +#include "OnnxruntimeEngine.h" +#include "OnnxruntimeErrors.h" + +using namespace Windows::AI::MachineLearning; + +struct winml_adapter_api_model_feature_helper { + decltype(WinmlAdapterApi::ModelGetInputCount) GetCount; + decltype(WinmlAdapterApi::ModelGetInputName) GetName; + decltype(WinmlAdapterApi::ModelGetInputDescription) GetDescription; + decltype(WinmlAdapterApi::ModelGetInputTypeInfo) GetTypeInfo; +}; + +HRESULT CreateFeatureDescriptors( + OnnxruntimeEngineFactory* engine_factory, + const winml_adapter_api_model_feature_helper* feature_helpers, + OrtModel* ort_model, + std::vector& descriptors) { + const auto ort_api = engine_factory->UseOrtApi(); + size_t count; + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetCount(ort_model, &count), + engine_factory->UseOrtApi()); + + for (size_t i = 0; i < count; i++) { + OnnxruntimeValueInfoWrapper descriptor; + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetName(ort_model, i, &descriptor.name_, &descriptor.name_length_), + engine_factory->UseOrtApi()); + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetDescription(ort_model, i, &descriptor.description_, &descriptor.description_length_), + engine_factory->UseOrtApi()); + + OrtTypeInfo* type_info; + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetTypeInfo(ort_model, i, &type_info), + engine_factory->UseOrtApi()); + + descriptor.type_info_ = UniqueOrtTypeInfo(type_info, ort_api->ReleaseTypeInfo); + + descriptors.push_back(std::move(descriptor)); + } + return S_OK; +} + +HRESULT ModelInfo::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, OrtModel* ort_model) { + RETURN_HR_IF_NULL(E_INVALIDARG, ort_model); + + const auto winml_adapter_api = engine_factory->UseWinmlAdapterApi(); + + // Get Metadata + size_t count; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetMetadataCount(ort_model, &count), + 
engine_factory->UseOrtApi()); + + const char* metadata_key; + size_t metadata_key_len; + const char* metadata_value; + size_t metadata_value_len; + for (size_t i = 0; i < count; i++) { + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetMetadata(ort_model, i, &metadata_key, &metadata_key_len, &metadata_value, &metadata_value_len), + engine_factory->UseOrtApi()); + + model_metadata_.insert_or_assign( + std::string(metadata_key, metadata_key_len), + std::string(metadata_value, metadata_value_len)); + } + + WinML::OnnxruntimeDescriptorConverter converter(engine_factory, model_metadata_); + + static const winml_adapter_api_model_feature_helper input_helpers = { + winml_adapter_api->ModelGetInputCount, + winml_adapter_api->ModelGetInputName, + winml_adapter_api->ModelGetInputDescription, + winml_adapter_api->ModelGetInputTypeInfo}; + + // Create inputs + std::vector inputs; + RETURN_IF_FAILED(CreateFeatureDescriptors(engine_factory, &input_helpers, ort_model, inputs)); + input_features_ = converter.ConvertToLearningModelDescriptors(inputs); + + // Create outputs + static const winml_adapter_api_model_feature_helper output_helpers = { + winml_adapter_api->ModelGetOutputCount, + winml_adapter_api->ModelGetOutputName, + winml_adapter_api->ModelGetOutputDescription, + winml_adapter_api->ModelGetOutputTypeInfo}; + + std::vector outputs; + RETURN_IF_FAILED(CreateFeatureDescriptors(engine_factory, &output_helpers, ort_model, outputs)); + output_features_ = converter.ConvertToLearningModelDescriptors(outputs); + + const char* out; + size_t len; + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetAuthor(ort_model, &out, &len), + engine_factory->UseOrtApi()); + author_ = std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetName(ort_model, &out, &len), + engine_factory->UseOrtApi()); + name_ = std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetDomain(ort_model, &out, &len), + engine_factory->UseOrtApi()); + domain_ = 
std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetDescription(ort_model, &out, &len), + engine_factory->UseOrtApi()); + description_ = std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetVersion(ort_model, &version_), + engine_factory->UseOrtApi()); + + return S_OK; +} + +STDMETHODIMP ModelInfo::GetAuthor(const char** out, size_t* len) { + *out = author_.c_str(); + *len = author_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetName(const char** out, size_t* len) { + *out = name_.c_str(); + *len = name_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetDomain(const char** out, size_t* len) { + *out = domain_.c_str(); + *len = domain_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetDescription(const char** out, size_t* len) { + *out = description_.c_str(); + *len = description_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetVersion(int64_t* out) { + *out = version_; + return S_OK; +} + +STDMETHODIMP ModelInfo::GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView** metadata) { + std::unordered_map map_copy; + for (auto& pair : model_metadata_) { + auto metadata_key = WinML::Strings::HStringFromUTF8(pair.first); + auto metadata_value = WinML::Strings::HStringFromUTF8(pair.second); + map_copy.emplace(std::move(metadata_key), std::move(metadata_value)); + } + auto map = winrt::single_threaded_map(std::move(map_copy)); + winrt::copy_to_abi(map, *(void**)metadata); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetInputFeatures(ABI::Windows::Foundation::Collections::IVectorView** features) { + *features = nullptr; + winrt::copy_to_abi(input_features_.GetView(), *(void**)features); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetOutputFeatures(ABI::Windows::Foundation::Collections::IVectorView** features) { + *features = nullptr; + winrt::copy_to_abi(output_features_.GetView(), *(void**)features); + return S_OK; +} + +OnnruntimeModel::OnnruntimeModel() : ort_model_(nullptr, 
nullptr) { +} + +STDMETHODIMP OnnruntimeModel::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, UniqueOrtModel&& ort_model) { + RETURN_HR_IF_NULL(E_INVALIDARG, ort_model); + + engine_factory_ = engine_factory; + ort_model_ = std::move(ort_model); + + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::GetModelInfo(IModelInfo** info) { + if (info_ == nullptr) { + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&info_, engine_factory_.Get(), ort_model_.get())); + } + + info_.CopyTo(info); + + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::ModelEnsureNoFloat16() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelEnsureNoFloat16(ort_model_.get()), + engine_factory_->UseOrtApi()); + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::CloneModel(IModel** copy) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtModel* ort_model_copy; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CloneModel(ort_model_.get(), &ort_model_copy), + engine_factory_->UseOrtApi()); + + auto model = UniqueOrtModel(ort_model_copy, winml_adapter_api->ReleaseModel); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(copy, engine_factory_.Get(), std::move(model))); + + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::DetachOrtModel(OrtModel** model) { + *model = ort_model_.release(); + return S_OK; +} diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.h b/winml/lib/Api.Ort/OnnxruntimeModel.h new file mode 100644 index 0000000000000..1be587cfc8b48 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeModel.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "iengine.h" + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineFactory; + +// The IOrtSessionBuilder offers an abstraction over the creation of +// InferenceSession, that enables the creation of the session based on a device (CPU/DML). 
+MIDL_INTERFACE("92679cbf-7a9d-48bb-b97f-ef9fb447ce8e") +IOnnxruntimeModel : IUnknown { + virtual HRESULT STDMETHODCALLTYPE DetachOrtModel(OrtModel * *model) PURE; +}; + +class ModelInfo : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IModelInfo> { + public: + HRESULT RuntimeClassInitialize(_In_ OnnxruntimeEngineFactory* engine, _In_ OrtModel* ort_model); + + STDMETHOD(GetAuthor) + (const char** out, size_t* len); + STDMETHOD(GetName) + (const char** out, size_t* len); + STDMETHOD(GetDomain) + (const char** out, size_t* len); + STDMETHOD(GetDescription) + (const char** out, size_t* len); + STDMETHOD(GetVersion) + (int64_t* out); + STDMETHOD(GetModelMetadata) + (ABI::Windows::Foundation::Collections::IMapView** metadata); + STDMETHOD(GetInputFeatures) + (ABI::Windows::Foundation::Collections::IVectorView** features); + STDMETHOD(GetOutputFeatures) + (ABI::Windows::Foundation::Collections::IVectorView** features); + + private: + std::string author_; + std::string name_; + std::string domain_; + std::string description_; + int64_t version_; + std::unordered_map model_metadata_; + wfc::IVector input_features_; + wfc::IVector output_features_; +}; + +class OnnruntimeModel : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IModel, + IOnnxruntimeModel> { + public: + OnnruntimeModel(); + + HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine, UniqueOrtModel&& ort_model); + + STDMETHOD(GetModelInfo) + (IModelInfo** info); + STDMETHOD(ModelEnsureNoFloat16) + (); + STDMETHOD(CloneModel) + (IModel** copy); + STDMETHOD(DetachOrtModel) + (OrtModel** model); + + private: + UniqueOrtModel ort_model_; + + Microsoft::WRL::ComPtr engine_factory_; + Microsoft::WRL::ComPtr info_; + + std::optional> metadata_cache_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h new file mode 100644 
index 0000000000000..372da3c792c9f --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning { + +// The IOrtSessionBuilder offers an abstraction over the creation of +// InferenceSession, that enables the creation of the session based on a device (CPU/DML). +MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") +IOrtSessionBuilder : IUnknown { + virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions( + OrtSessionOptions * *options) = 0; + + virtual HRESULT STDMETHODCALLTYPE CreateSession( + OrtSessionOptions * options, + OrtSession * *session) = 0; + + virtual HRESULT STDMETHODCALLTYPE Initialize( + OrtSession * session) = 0; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h b/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h new file mode 100644 index 0000000000000..120965f4a7e80 --- /dev/null +++ b/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "iengine.h" + +STDAPI CreateOnnxruntimeEngineFactory(_Out_ Windows::AI::MachineLearning::IEngineFactory** engine_factory); \ No newline at end of file diff --git a/winml/lib/Api.Ort/pch.h b/winml/lib/Api.Ort/pch.h new file mode 100644 index 0000000000000..e41ad60623e9b --- /dev/null +++ b/winml/lib/Api.Ort/pch.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#include "winrt_headers.h" + +#include "core/providers/winml/winml_provider_factory.h" +#include "adapter/winml_adapter_c_api.h" + +using UniqueOrtModel = std::unique_ptr; +using UniqueOrtSessionOptions = std::unique_ptr; +using UniqueOrtSession = std::unique_ptr; +using UniqueOrtExecutionProvider = std::unique_ptr; +using UniqueOrtValue = std::unique_ptr; +using UniqueOrtMemoryInfo = std::unique_ptr; +using UniqueOrtTypeInfo = std::unique_ptr; +using UniqueOrtTensorTypeAndShapeInfo = std::unique_ptr; +using UniqueOrtAllocator = std::unique_ptr; +using UniqueOrtRunOptions = std::unique_ptr; diff --git a/winml/lib/Api/FeatureValues.h b/winml/lib/Api/FeatureValues.h index b621238768b84..637a4a2c61a74 100644 --- a/winml/lib/Api/FeatureValues.h +++ b/winml/lib/Api/FeatureValues.h @@ -58,7 +58,7 @@ \ type(std::vector const& shape) : Base(shape){}; \ \ - type(std::vector const& shape, ID3D12Resource* pResource, UINT64 resource_width) : Base(shape, pResource, resource_width){}; \ + type(std::vector const& shape, ID3D12Resource* pResource) : Base(shape, pResource){}; \ }; \ } \ namespace winrt::Windows::AI::MachineLearning::factory_implementation { \ @@ -85,7 +85,7 @@ CREATE_TENSOR(TensorUInt32Bit, uint32_t, uint32_t) CREATE_TENSOR(TensorInt32Bit, int32_t, int32_t) CREATE_TENSOR(TensorUInt64Bit, uint64_t, uint64_t) CREATE_TENSOR(TensorInt64Bit, int64_t, int64_t) -CREATE_TENSOR(TensorFloat16Bit, onnxruntime::MLFloat16, float) +CREATE_TENSOR(TensorFloat16Bit, WinML::Half, float) #pragma warning(push) #pragma warning(disable : 4702) // Unreachable code (one of TensorBase's constructor unconditionally throws for diff --git a/winml/lib/Api/ImageFeatureDescriptor.cpp b/winml/lib/Api/ImageFeatureDescriptor.cpp index ac53db7d6e7df..e2f1e70000512 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.cpp +++ b/winml/lib/Api/ImageFeatureDescriptor.cpp @@ -11,9 +11,9 @@ namespace winrt::Windows::AI::MachineLearning::implementation { 
ImageFeatureDescriptor::ImageFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, wgi::BitmapPixelFormat pixel_format, wgi::BitmapAlphaMode alpha_mode, uint32_t width, @@ -32,28 +32,6 @@ ImageFeatureDescriptor::ImageFeatureDescriptor( color_space_gamma_(color_space_gamma) { } -ImageFeatureDescriptor::ImageFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& TensorKind, - array_view Shape, - Windows::Graphics::Imaging::BitmapPixelFormat const& BitmapPixelFormat, - Windows::Graphics::Imaging::BitmapAlphaMode const& BitmapAlphaMode, - uint32_t Width, - uint32_t Height) : name_(Name), - description_(Description), - tensor_kind_(TensorKind), - shape_(Shape.begin(), Shape.end()), - is_required_(IsRequired), - pixel_format_(BitmapPixelFormat), - alpha_mode_(BitmapAlphaMode), - width_(Width), - height_(Height), - nominal_pixel_range_(ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_0_255), - color_space_gamma_(ImageColorSpaceGamma::ImageColorSpaceGamma_SRGB) { -} - wgi::BitmapPixelFormat ImageFeatureDescriptor::BitmapPixelFormat() try { return pixel_format_; diff --git a/winml/lib/Api/ImageFeatureDescriptor.h b/winml/lib/Api/ImageFeatureDescriptor.h index 336e4230dd9b4..54f1f265b3724 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.h +++ b/winml/lib/Api/ImageFeatureDescriptor.h @@ -24,9 +24,9 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, wgi::BitmapPixelFormat pixelformat, wgi::BitmapAlphaMode alphamode, uint32_t width, @@ -34,17 +34,6 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageNominalPixelRange nominalPixelRange, ImageColorSpaceGamma colorSpaceGamma); - 
ImageFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& TensorKind, - array_view Shape, - Windows::Graphics::Imaging::BitmapPixelFormat const& BitmapPixelFormat, - Windows::Graphics::Imaging::BitmapAlphaMode const& BitmapAlphaMode, - uint32_t Width, - uint32_t Height); - wgi::BitmapPixelFormat BitmapPixelFormat(); @@ -104,10 +93,4 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageNominalPixelRange nominal_pixel_range_; ImageColorSpaceGamma color_space_gamma_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - -namespace winrt::Windows::AI::MachineLearning::factory_implementation { - struct ImageFeatureDescriptor : ImageFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/ImageFeatureValue.cpp b/winml/lib/Api/ImageFeatureValue.cpp index f24d41489130f..983ee677471de 100644 --- a/winml/lib/Api/ImageFeatureValue.cpp +++ b/winml/lib/Api/ImageFeatureValue.cpp @@ -20,10 +20,6 @@ #include "D3DDeviceCache.h" #include "TensorFeatureDescriptor.h" -// Uncomment to enable DEBUG_IMAGE_TENSOR_RESOURCE and -// allow debugging the content of the resource -//#define DEBUG_IMAGE_TENSOR_RESOURCE - using namespace WinML; using namespace winrt::Windows::Graphics::Imaging; using namespace winrt::Windows::Graphics::DirectX::Direct3D11; @@ -38,87 +34,6 @@ struct ImageFeatureValue::ImageResourceMetadata { ::Windows::AI::MachineLearning::Internal::ImageTensorDescription TensorDescriptor; }; -#ifdef ENABLE_IMAGE_FEATURE_VALUE_TENSOR_DUMP -static void DumpResourceToCPU( - ID3D12Resource* pResource, - com_ptr spSession, - ImageTensorDescription tensorDescriptor, - ::Windows::AI::MachineLearning::Internal::TensorToVideoFrameConverter* tensorToImageConverter) { - auto spDevice = 
spSession->Device().as(); - auto spD3DDevice = spDevice->GetD3DDevice(); - auto spCommandQueue = spDevice->GetDeviceQueue(); - auto pProvider = spSession->GetExecutionProvider(); - - UINT64 bufferbytesize = pResource->GetDesc().Width; - - Dml::FlushContext(pProvider); - - D3D12_HEAP_PROPERTIES heapProperties = { - D3D12_HEAP_TYPE_READBACK, - D3D12_CPU_PAGE_PROPERTY_UNKNOWN, - D3D12_MEMORY_POOL_UNKNOWN, - 0, - 0}; - D3D12_RESOURCE_DESC resourceDesc = { - D3D12_RESOURCE_DIMENSION_BUFFER, - 0, - bufferbytesize, - 1, - 1, - 1, - DXGI_FORMAT_UNKNOWN, - {1, 0}, - D3D12_TEXTURE_LAYOUT_ROW_MAJOR, - D3D12_RESOURCE_FLAG_NONE}; - - ID3D12Resource* pCPUResource = nullptr; - spD3DDevice->CreateCommittedResource( - &heapProperties, - D3D12_HEAP_FLAG_NONE, - &resourceDesc, - D3D12_RESOURCE_STATE_COPY_DEST, - nullptr, - IID_PPV_ARGS(&pCPUResource)); - - { - ScopedCommandList scopedCommandList(spSession); - // Record command list copy action - scopedCommandList.get()->CopyResource(pCPUResource, pResource); - scopedCommandList.get()->Close(); - ID3D12CommandList* pCommandLists[] = {scopedCommandList.get()}; - spCommandQueue->ExecuteCommandLists(ARRAYSIZE(pCommandLists), pCommandLists); - - // TODO: Do we need to set a fence here and wait for completion before - // reading the resource in cpu memory? 
- } - - D3D12_RANGE range = {0, static_cast(bufferbytesize)}; - - void* pData = nullptr; - pCPUResource->Map(0, &range, reinterpret_cast(&pData)); - - range.End = 0; - - DebugBreak(); - - SoftwareBitmap bitmap(BitmapPixelFormat::Bgra8, 720, 720); - Windows::Media::VideoFrame frame = Windows::Media::VideoFrame::CreateWithSoftwareBitmap(bitmap); - tensorToImageConverter->SoftwareTensorToVideoFrame( - spSession.as(), - reinterpret_cast(pData), - tensorDescriptor, - frame); - - auto folder = Windows::Storage::StorageFolder::GetFolderFromPathAsync(L"C:\\").get(); - auto imagefile = folder.CreateFileAsync(L"out.png", Windows::Storage::CreationCollisionOption::ReplaceExisting).get(); - auto stream = imagefile.OpenAsync(Windows::Storage::FileAccessMode::ReadWrite).get(); - auto encoder = BitmapEncoder::CreateAsync(BitmapEncoder::JpegEncoderId(), stream).get(); - encoder.SetSoftwareBitmap(frame.SoftwareBitmap()); - encoder.FlushAsync(); - pResource->Unmap(0, &range); -} -#endif - Windows::AI::MachineLearning::ImageFeatureValue ImageFeatureValue::Create( uint32_t batchSize, BitmapPixelFormat format, @@ -329,13 +244,12 @@ static void CPUTensorize( std::vector bounds, ImageTensorDescription tensorDescriptor, com_ptr spSession, - void* pResource, + BYTE* resource, unsigned int singleFrameBufferSize) { // Tensorize video frames one by one without extra copy. 
- BYTE* tempPResource = reinterpret_cast(pResource); for (uint32_t batchIdx = 0; batchIdx < videoFrames.Size(); ++batchIdx) { - CPUTensorize(videoFrames.GetAt(batchIdx), bounds[batchIdx], tensorDescriptor, spSession, tempPResource); - tempPResource += singleFrameBufferSize; + CPUTensorize(videoFrames.GetAt(batchIdx), bounds[batchIdx], tensorDescriptor, spSession, resource); + resource += singleFrameBufferSize; } } @@ -344,15 +258,8 @@ static void GPUTensorize( std::vector bounds, ImageTensorDescription tensorDescriptor, com_ptr spSession, - void* pAllocatedResource, + ID3D12Resource* d3dResource, WinML::BindingContext& context) { - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - - auto d3dResource = - adapter->GetD3D12ResourceFromAllocation( - spSession->GetExecutionProvider(), - pAllocatedResource); auto spDevice = spSession->Device().as(); ConverterResourceDescription descriptor = {}; @@ -386,9 +293,6 @@ static void GPUTensorize( context.converter = pooledConverter; } } -#ifdef DEBUG_IMAGE_TENSOR_RESOURCE - DumpResourceToCPU(d3dResource, spSession, tensorDescriptor); -#endif } std::optional ImageFeatureValue::GetInputMetadata(const WinML::BindingContext& context) { @@ -490,7 +394,7 @@ std::optional ImageFeatureValue::GetIn return ImageResourceMetadata{bounds, imageTensorDescriptor}; } -HRESULT ImageFeatureValue::GetOrtValue(WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator) try { +HRESULT ImageFeatureValue::GetValue(WinML::BindingContext& context, IValue** out) try { FAIL_FAST_IF(!(std::all_of(m_widths.begin(), m_widths.end(), [](int i) { return i != 0; }))); FAIL_FAST_IF(!(std::all_of(m_heights.begin(), m_heights.end(), [](int i) { return i != 0; }))); @@ -502,27 +406,19 @@ HRESULT ImageFeatureValue::GetOrtValue(WinML::BindingContext& context, OrtValue* // Get the session auto spSession = context.session.as(); auto spDevice = spSession->Device().as(); - auto provider = 
spSession->GetExecutionProvider(); - - // and the adapter - if (!m_adapter) { - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(m_adapter.put())); - } + auto engine = spSession->GetEngine(); // create the OrtValue - Ort::Allocator dml_allocator(m_adapter.get(), nullptr); - WINML_THROW_IF_FAILED(m_adapter->GetProviderAllocator(provider, dml_allocator.put())); - - // create the OrtValue as a tensor letting ort know that we own the data buffer - Ort::Value ort_tensor = Ort::Value::CreateTensor( - dml_allocator, - &(resourceMetadata.TensorDescriptor.sizes[0]), + winrt::com_ptr value; + RETURN_IF_FAILED(engine->CreateTensorValue( + resourceMetadata.TensorDescriptor.sizes, sizeof(resourceMetadata.TensorDescriptor.sizes) / sizeof(resourceMetadata.TensorDescriptor.sizes[0]), - (resourceMetadata.TensorDescriptor.dataType == kImageTensorDataTypeFloat32) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16); + resourceMetadata.TensorDescriptor.dataType == kImageTensorDataTypeFloat32 ? 
winml::TensorKind::Float : winml::TensorKind::Float16, + value.put())); // Get the tensor raw data - void* pAllocatedResource = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(ort_tensor, &pAllocatedResource)); + WinML::Resource void_resource; + RETURN_IF_FAILED(value->GetResource(void_resource)); if (context.type == BindingType::kInput) { // Only tensorize inputs @@ -530,15 +426,15 @@ HRESULT ImageFeatureValue::GetOrtValue(WinML::BindingContext& context, OrtValue* auto bufferByteSize = GetSizeFromTensorDataType(resourceMetadata.TensorDescriptor.dataType) * bufferSize; auto singleFrameBufferSize = bufferByteSize / m_batchSize; if (spDevice->IsCpuDevice()) { - CPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, pAllocatedResource, static_cast(singleFrameBufferSize)); - } - else { - GPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, pAllocatedResource, context); + auto resource = reinterpret_cast(void_resource.get()); + CPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, resource, static_cast(singleFrameBufferSize)); + } else { + auto resource = reinterpret_cast(void_resource.get()); + GPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, resource, context); } } - *ort_value = ort_tensor.release(); - *ort_allocator = dml_allocator.release(); + *out = value.detach(); return S_OK; } WINML_CATCH_ALL_COM @@ -549,18 +445,14 @@ HRESULT ImageFeatureValue::IsPlaceholder(bool* pIsPlaceHolder) { return S_OK; } -HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, OrtValue* ort_value) try { +HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, IValue* value) try { // Get the device auto spSession = context.session.as(); auto spDevice = spSession->Device().as(); - if (!m_adapter) { - 
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(m_adapter.put())); - } - // Get the output tensor raw data - void* pAllocatedResource = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(ort_value, &pAllocatedResource)); + WinML::Resource void_resource; + RETURN_IF_FAILED(value->GetResource(void_resource)); // Get the run context auto metadata = GetInputMetadata(context); @@ -570,36 +462,30 @@ HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, Ort descriptor.width = static_cast(resourceMetadata.TensorDescriptor.sizes[3]); descriptor.height = static_cast(resourceMetadata.TensorDescriptor.sizes[2]); - Ort::MemoryInfo memory_info(nullptr); - m_adapter->GetValueMemoryInfo(ort_value, memory_info.put()); - - if (!strcmp(memory_info.Name(), onnxruntime::CPU) || - memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUOutput || - memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUInput) { + bool out; + if (SUCCEEDED(value->IsCpu(&out)) && out) { descriptor.pixel_format = static_cast(BitmapPixelFormat::Bgra8); descriptor.luid = {}; // Converted image on CPU auto pooledConverter = PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); - auto bufferSize = std::accumulate(std::begin(resourceMetadata.TensorDescriptor.sizes), std::end(resourceMetadata.TensorDescriptor.sizes), static_cast< int64_t>(1), std::multiplies()); + auto bufferSize = std::accumulate(std::begin(resourceMetadata.TensorDescriptor.sizes), std::end(resourceMetadata.TensorDescriptor.sizes), static_cast(1), std::multiplies()); auto bufferByteSize = GetSizeFromTensorDataType(resourceMetadata.TensorDescriptor.dataType) * bufferSize / m_batchSize; - BYTE* tempPAllocatedResource = reinterpret_cast(pAllocatedResource); + BYTE* resource = reinterpret_cast(void_resource.get()); for (uint32_t batchIdx = 0; batchIdx < m_batchSize; ++batchIdx) { // Convert Software Tensor to VideoFrame one by one based on the buffer size. 
auto videoFrame = m_videoFrames.GetAt(batchIdx); - pooledConverter->Get()->Detensorizer->SoftwareTensorToVideoFrame(context.session, tempPAllocatedResource, resourceMetadata.TensorDescriptor, videoFrame); - tempPAllocatedResource += bufferByteSize; + pooledConverter->Get()->Detensorizer->SoftwareTensorToVideoFrame(context.session, resource, resourceMetadata.TensorDescriptor, videoFrame); + resource += bufferByteSize; } - } - else { + } else { descriptor.pixel_format = static_cast(DirectXPixelFormat::B8G8R8X8UIntNormalized); descriptor.luid = spDevice->GetD3DDevice()->GetAdapterLuid(); // Converted image on GPU auto pooledConverter = PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); - auto pProvider = spSession->GetExecutionProvider(); - auto d3dResource = m_adapter->GetD3D12ResourceFromAllocation(pProvider, pAllocatedResource); + auto d3dResource = reinterpret_cast(void_resource.get()); for (uint32_t batchIdx = 0; batchIdx < m_batchSize; ++batchIdx) { auto videoFrame = m_videoFrames.GetAt(batchIdx); @@ -614,9 +500,6 @@ HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, Ort spDevice->GetD3DDeviceCache()->SyncD3D12ToCPU(); pooledConverter->Get()->Detensorizer->ResetAllocator(); } -#ifdef DEBUG_IMAGE_TENSOR_RESOURCE - DumpResourceToCPU(d3dResource, spSession, resourceInfo.Metadata.TensorDescriptor); -#endif } // Release any converters back to the pool by nulling out the wrapper. 
diff --git a/winml/lib/Api/ImageFeatureValue.h b/winml/lib/Api/ImageFeatureValue.h index d826c12e231ba..c135d2fee4a3d 100644 --- a/winml/lib/Api/ImageFeatureValue.h +++ b/winml/lib/Api/ImageFeatureValue.h @@ -32,20 +32,20 @@ struct ImageFeatureValue : ImageFeatureValueT GetInputMetadata(const WinML::BindingContext& context); // ILotusValueProviderPrivate implementation - STDMETHOD(GetOrtValue) - (WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator); + STDMETHOD(GetValue) + (WinML::BindingContext& context, WinML::IValue** out); STDMETHOD(IsPlaceholder) (bool* pIsPlaceHolder); STDMETHOD(UpdateSourceResourceData) - (WinML::BindingContext& context, OrtValue* ort_value); + (WinML::BindingContext& context, WinML::IValue* value); STDMETHOD(AbiRepresentation) (winrt::Windows::Foundation::IInspectable& abiRepresentation); std::vector Widths() { return m_widths; } std::vector Heights() { return m_heights; } bool IsBatch() { return m_batchSize > 1; } + private: - com_ptr m_adapter; winrt::Windows::Foundation::Collections::IVector m_videoFrames; std::vector m_widths = {}; std::vector m_heights = {}; diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index 51c8a02e04e2e..13b685b963ad5 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -10,6 +10,10 @@ #include "SequenceFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" +#include "OnnxruntimeProvider.h" + +#include + namespace winrt::Windows::AI::MachineLearning::implementation { LearningModel::LearningModel( const hstring& path, @@ -22,70 +26,96 @@ LearningModel::LearningModel( const std::string& path, const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad); - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); - WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); - 
WINML_THROW_IF_FAILED(adapter_->CreateModelProto(path.c_str(), model_proto_.put())); - - Initialize(); + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); + WINML_THROW_IF_FAILED(engine_factory_->CreateModel(path.c_str(), path.size(), model_.put())); + WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } WINML_CATCH_ALL +static HRESULT CreateModelFromStream( + WinML::IEngineFactory* engine_factory, + const wss::IRandomAccessStreamReference stream, + WinML::IModel** model) { + auto content = stream.OpenReadAsync().get(); + + wss::Buffer buffer(static_cast(content.Size())); + auto result = content.ReadAsync( + buffer, + buffer.Capacity(), + wss::InputStreamOptions::None) + .get(); + + auto bytes = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>(); + WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes, "Model stream is invalid."); + + void* data; + WINML_THROW_IF_FAILED_MSG(bytes->Buffer(reinterpret_cast(&data)), "Failed to acquire buffer from model stream."); + + size_t len = static_cast(content.Size()); + WINML_THROW_IF_FAILED(engine_factory->CreateModel(data, len, model)); + + return S_OK; +} + LearningModel::LearningModel( const wss::IRandomAccessStreamReference stream, const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad); - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); - WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); - WINML_THROW_IF_FAILED(adapter_->CreateModelProto( - static_cast(winrt::get_abi(stream)), - model_proto_.put())); - - Initialize(); + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); + WINML_THROW_IF_FAILED(CreateModelFromStream(engine_factory_.get(), stream, model_.put())); + WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } WINML_CATCH_ALL -void LearningModel::Initialize() { - 
WINML_THROW_IF_FAILED(adapter_->CreateModelInfo(model_proto_.get(), model_info_.put())); -} - hstring LearningModel::Author() try { - return WinML::Strings::HStringFromUTF8(model_info_->author()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetAuthor(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Name() try { - return WinML::Strings::HStringFromUTF8( - model_info_->name()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetName(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Domain() try { - return WinML::Strings::HStringFromUTF8( - model_info_->domain()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetDomain(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Description() try { - return WinML::Strings::HStringFromUTF8( - model_info_->description()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetDescription(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL int64_t LearningModel::Version() try { - return model_info_->version(); + int64_t version; + WINML_THROW_IF_FAILED(model_info_->GetVersion(&version)); + return version; } WINML_CATCH_ALL wfc::IMapView LearningModel::Metadata() try { - ABI::Windows::Foundation::Collections::IMapView* metadata; + ABI::Windows::Foundation::Collections::IMapView* metadata = nullptr; wfc::IMapView out; WINML_THROW_IF_FAILED(model_info_->GetModelMetadata(&metadata)); winrt::attach_abi(out, metadata); @@ -104,13 +134,14 @@ LearningModel::GetOperatorRegistry() { operator_provider_.as(); IMLOperatorRegistry* registry = nullptr; - WINML_THROW_IF_FAILED(adapter_->GetOperatorRegistry(operator_provider_native.get(), ®istry)); + // Retrieve the "operator abi" registry. 
+ THROW_IF_FAILED(operator_provider_native->GetRegistry(®istry)); return registry; } wfc::IVectorView LearningModel::InputFeatures() try { - ABI::Windows::Foundation::Collections::IVectorView* features; + ABI::Windows::Foundation::Collections::IVectorView* features = nullptr; wfc::IVectorView out; WINML_THROW_IF_FAILED(model_info_->GetInputFeatures(&features)); winrt::attach_abi(out, features); @@ -120,7 +151,7 @@ WINML_CATCH_ALL wfc::IVectorView LearningModel::OutputFeatures() try { - ABI::Windows::Foundation::Collections::IVectorView* features; + ABI::Windows::Foundation::Collections::IVectorView* features = nullptr; wfc::IVectorView out; WINML_THROW_IF_FAILED(model_info_->GetOutputFeatures(&features)); winrt::attach_abi(out, features); @@ -130,12 +161,12 @@ WINML_CATCH_ALL void LearningModel::Close() try { // close the model - model_proto_ = nullptr; + model_ = nullptr; } WINML_CATCH_ALL bool LearningModel::IsDisposed() { - return model_proto_ == nullptr; + return model_ == nullptr; } wf::IAsyncOperation @@ -196,30 +227,33 @@ LearningModel::LoadFromStream( } WINML_CATCH_ALL -winmla::IModelProto* -LearningModel::DetachModelProto() { - com_ptr detached_model_proto; - if (model_proto_ != nullptr) { - detached_model_proto.attach(model_proto_.detach()); +WinML::IModel* +LearningModel::DetachModel() { + com_ptr detached_model; + if (model_ != nullptr) { + detached_model.attach(model_.detach()); // Close the model since we now own the model proto Close(); } - return detached_model_proto.detach(); + return detached_model.detach(); } -winmla::IModelProto* -LearningModel::CopyModelProto() { - if (model_proto_ == nullptr) { +WinML::IModel* +LearningModel::CloneModel() { + if (model_ == nullptr) { return nullptr; } - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - com_ptr model_proto; - WINML_THROW_IF_FAILED(adapter->CreateModelProto(model_proto_.get(), model_proto.put())); + com_ptr model_copy; + 
WINML_THROW_IF_FAILED(model_->CloneModel(model_copy.put())); + + return model_copy.detach(); +} - return model_proto.detach(); +WinML::IEngineFactory* +LearningModel::GetEngineFactory() { + return engine_factory_.get(); } } // namespace winrt::Windows::AI::MachineLearning::implementation diff --git a/winml/lib/Api/LearningModel.h b/winml/lib/Api/LearningModel.h index 261dc4b8655fa..e00eb6339824a 100644 --- a/winml/lib/Api/LearningModel.h +++ b/winml/lib/Api/LearningModel.h @@ -4,7 +4,12 @@ #pragma once #include "LearningModel.g.h" -#include "core/providers/winml/winml_provider_factory.h" + +namespace Windows::AI::MachineLearning { +struct IEngineFactory; +struct IModel; +struct IModelInfo; +} // namespace Windows::AI::MachineLearning namespace winrt::Windows::AI::MachineLearning::implementation { @@ -93,20 +98,15 @@ struct LearningModel : LearningModelT { /* Non-ABI methods */ bool IsDisposed(); IMLOperatorRegistry* GetOperatorRegistry(); - winmla::IModelProto* DetachModelProto(); - winmla::IModelProto* CopyModelProto(); + WinML::IModel* DetachModel(); + WinML::IModel* CloneModel(); + WinML::IEngineFactory* GetEngineFactory(); private: - void Initialize(); - void LogCreationEvent(bool fromStream = false); - void ModelUseFP16( - winml::ILearningModelFeatureDescriptor descriptor, - bool& use_fp16); + com_ptr engine_factory_; + com_ptr model_; + com_ptr model_info_; - private: - com_ptr adapter_; - com_ptr model_proto_; - com_ptr model_info_; ILearningModelOperatorProvider operator_provider_; }; diff --git a/winml/lib/Api/LearningModelBinding.cpp b/winml/lib/Api/LearningModelBinding.cpp index 8076f2fb008a1..65a592a3fbef3 100644 --- a/winml/lib/Api/LearningModelBinding.cpp +++ b/winml/lib/Api/LearningModelBinding.cpp @@ -17,7 +17,6 @@ namespace winrt::Windows::AI::MachineLearning::implementation { LearningModelBinding::LearningModelBinding( Windows::AI::MachineLearning::LearningModelSession const& session) try : m_session(session) { session.as()->CheckClosed(); - 
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); } WINML_CATCH_ALL @@ -39,10 +38,6 @@ static Windows::AI::MachineLearning::ILearningModelFeatureDescriptor FindValidBi return nullptr; } -LearningModelBinding::~LearningModelBinding() { - Clear(); -} - using NullableBindingPort = std::optional>; static NullableBindingPort FindValidBinding( @@ -63,7 +58,7 @@ void LearningModelBinding::CacheProvider( m_providers[name] = providerInfo; } -std::tuple LearningModelBinding::CreateBinding( +std::tuple, BindingType> LearningModelBinding::CreateBinding( const std::string& name, const Windows::Foundation::IInspectable& inspectable, Windows::Foundation::Collections::IPropertySet const& properties) { @@ -102,10 +97,9 @@ std::tuple LearningModelBind }; // Get the bound tensor - Ort::Value value(nullptr); - Ort::Allocator ort_allocator(adapter_.get(), nullptr); + winrt::com_ptr value; - // Get the native ORT interface for the given bind value + // Get the native interface for the given bind value auto spLotusValueProvider = featureValue.as(); auto spSession = m_session.as(); @@ -126,7 +120,7 @@ std::tuple LearningModelBind if (!isPlaceHolder || shouldAlwaysTensorize) { // If not a placeholder, attempt to get the underlying resource WINML_THROW_IF_FAILED_MSG( - spLotusValueProvider->GetOrtValue(context, value.put(), ort_allocator.put()), + spLotusValueProvider->GetValue(context, value.put()), "The model variable %s failed tensorization.", name.c_str()); } else { @@ -135,13 +129,15 @@ std::tuple LearningModelBind isPlaceHolder && bindingType == BindingType::kInput, "The model variable %s is an input, but has no associated resources to bind.", name.c_str()); + + WINML_THROW_IF_FAILED(spSession->GetEngine()->CreateNullValue(value.put())); } // Hold onto the input output providers so that our memory doesnt get destroyed! 
auto providerInfo = ProviderInfo{inspectable, spLotusValueProvider, context}; CacheProvider(name, providerInfo); - - return std::make_tuple(name, value.release(), bindingType, ort_allocator.release()); + + return std::make_tuple(name, value, bindingType); } void LearningModelBinding::Bind( @@ -157,26 +153,17 @@ void LearningModelBinding::Bind( Windows::Foundation::Collections::IPropertySet const& properties) try { _winmlt::TelemetryEvent binding_event(_winmlt::EventCategory::kBinding); - BindingType bindingType; - std::string bindingName; - OrtValue* binding_value = nullptr; - OrtAllocator* ort_allocator = nullptr; + BindingType binding_type; + std::string binding_name; + winrt::com_ptr binding_value = nullptr; auto featureName = WinML::Strings::UTF8FromHString(name); - std::tie(bindingName, binding_value, bindingType, ort_allocator) = CreateBinding(featureName, value, properties); - Ort::Value ortValue = binding_value ? Ort::Value(binding_value) : Ort::Value(nullptr); - Ort::Allocator ortAllocator(adapter_.get(), ort_allocator); - switch (bindingType) { + std::tie(binding_name, binding_value, binding_type) = CreateBinding(featureName, value, properties); + switch (binding_type) { case BindingType::kInput: - WINML_THROW_IF_FAILED(BindInput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindInput(binding_name, binding_value)); break; case BindingType::kOutput: - WINML_THROW_IF_FAILED(BindOutput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindOutput(binding_name, binding_value)); break; default: FAIL_FAST(); @@ -191,8 +178,6 @@ void LearningModelBinding::Clear() try { outputs_.clear(); output_names_.clear(); m_providers.clear(); - input_allocators_.clear(); - output_allocators_.clear(); } WINML_CATCH_ALL @@ -208,14 +193,14 @@ Windows::Foundation::Collections::IIterator } Windows::Foundation::IInspectable LearningModelBinding::Lookup(hstring const& key) { - auto utf8Name = 
WinML::Strings::UTF8FromHString(key); + auto utf8_name = WinML::Strings::UTF8FromHString(key); - auto foundIt = m_providers.find(utf8Name); + auto foundIt = m_providers.find(utf8_name); WINML_THROW_HR_IF_FALSE_MSG( E_BOUNDS, foundIt != std::end(m_providers), "The binding collection does not contain a variable with name %s.", - utf8Name.c_str()); + utf8_name.c_str()); auto providerInfo = foundIt->second; return providerInfo.CallerSpecifiedFeatureValue; @@ -226,8 +211,8 @@ uint32_t LearningModelBinding::Size() { } bool LearningModelBinding::HasKey(hstring const& key) { - auto utf8Name = WinML::Strings::UTF8FromHString(key); - return m_providers.find(utf8Name) != m_providers.end(); + auto utf8_name = WinML::Strings::UTF8FromHString(key); + return m_providers.find(utf8_name) != m_providers.end(); } void LearningModelBinding::Split( @@ -239,169 +224,110 @@ void LearningModelBinding::Split( second = nullptr; } -ONNXTensorElementDataType STDMETHODCALLTYPE GetONNXTensorElementDataType(winml::TensorKind kind) { - if (kind == TensorKind::Float) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (kind == TensorKind::Double) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } else if (kind == TensorKind::String) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - } else if (kind == TensorKind::UInt8) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; - } else if (kind == TensorKind::Int8) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; - } else if (kind == TensorKind::UInt16) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; - } else if (kind == TensorKind::Int16) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; - } else if (kind == TensorKind::UInt32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; - } else if (kind == TensorKind::Int32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; - } else if (kind == TensorKind::UInt64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; - } else if (kind == TensorKind::Int64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } else if 
(kind == TensorKind::Boolean) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; - } else if (kind == TensorKind::Float16) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - } - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; -} - -bool LearningModelBinding::IsOfMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind) { - if (ort_value.GetTypeInfo().GetONNXType() != ONNX_TYPE_MAP) - return false; - - ONNXTensorElementDataType onnx_key_type; - ONNXTensorElementDataType onnx_value_type; - - WINML_THROW_IF_FAILED(adapter_->GetMapType(ort_value, &onnx_key_type, &onnx_value_type)); - - if (onnx_key_type != GetONNXTensorElementDataType(key_kind)) - return false; - - if (onnx_value_type != GetONNXTensorElementDataType(value_kind)) - return false; - - return true; -}; - -bool LearningModelBinding::IsOfVectorMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind) { - if (ort_value.GetTypeInfo().GetONNXType() != ONNX_TYPE_SEQUENCE) - return false; - - ONNXTensorElementDataType onnx_key_type; - ONNXTensorElementDataType onnx_value_type; - - WINML_THROW_IF_FAILED(adapter_->GetVectorMapType(ort_value, &onnx_key_type, &onnx_value_type)); - - if (onnx_key_type != GetONNXTensorElementDataType(key_kind)) - return false; - - if (onnx_value_type != GetONNXTensorElementDataType(value_kind)) - return false; - - return true; -}; - -bool LearningModelBinding::IsOfTensorType(const Ort::Value& ort_value, TensorKind kind) { - return ort_value.GetTensorTypeAndShapeInfo().GetElementType() == GetONNXTensorElementDataType(kind); -}; - ILearningModelFeatureValue LearningModelBinding::CreateUnboundOuputFeatureValue( - const Ort::Value& ort_value, + const winrt::com_ptr value, ILearningModelFeatureDescriptor& descriptor) { - if (ort_value.IsTensor()) { - if (IsOfTensorType(ort_value, TensorKind::Float)) { + bool out; + if (SUCCEEDED(value->IsTensor(&out)) && out) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Float, &out)) && out) { if 
(descriptor.Kind() == LearningModelFeatureKind::Image) { using namespace Windows::Graphics::Imaging; // TODO: this format for unbound output needs more discussion BitmapPixelFormat format = descriptor.as()->BitmapPixelFormat(); - uint32_t width = static_cast(ort_value.GetTensorTypeAndShapeInfo().GetShape()[3]); - uint32_t height = static_cast(ort_value.GetTensorTypeAndShapeInfo().GetShape()[2]); - uint32_t batchSize = static_cast(ort_value.GetTensorTypeAndShapeInfo().GetShape()[0]); + std::vector shape; + value->GetTensorShape(shape); + uint32_t width = static_cast(shape[3]); + uint32_t height = static_cast(shape[2]); + uint32_t batchSize = static_cast(shape[0]); return implementation::ImageFeatureValue::Create(batchSize, format, width, height); } else { return implementation::TensorFloat::Create(); } } - if (IsOfTensorType(ort_value, TensorKind::Double)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Double, &out)) && out) { return implementation::TensorDouble::Create(); } - if (IsOfTensorType(ort_value, TensorKind::String)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::String, &out)) && out) { return implementation::TensorString::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt8)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::UInt8, &out)) && out) { return implementation::TensorUInt8Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int8)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int8, &out)) && out) { return implementation::TensorInt8Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt16)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::UInt16, &out)) && out) { return implementation::TensorUInt16Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int16)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int16, &out)) && out) { return implementation::TensorInt16Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt32)) { + if 
(SUCCEEDED(value->IsOfTensorType(TensorKind::UInt32, &out)) && out) { return implementation::TensorUInt32Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int32)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int32, &out)) && out) { return implementation::TensorInt32Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt64)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::UInt64, &out)) && out) { return implementation::TensorUInt64Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int64)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int64, &out)) && out) { return implementation::TensorInt64Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Boolean)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Boolean, &out)) && out) { return implementation::TensorBoolean::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Float16)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Float16, &out)) && out) { return implementation::TensorFloat16Bit::Create(); } } + // Maps - else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::String)) { + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::String, &out)) && out) { return implementation::MapStringToString::Create(); - } else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::Int64)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::Int64, &out)) && out) { return implementation::MapStringToInt64Bit::Create(); - } else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::Float)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::Float, &out)) && out) { return implementation::MapStringToFloat::Create(); - } else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::Double)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::Double, &out)) && out) { return implementation::MapStringToDouble::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, 
TensorKind::String)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::String, &out)) && out) { return implementation::MapInt64BitToString::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, TensorKind::Int64)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::Int64, &out)) && out) { return implementation::MapInt64BitToInt64Bit::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, TensorKind::Float)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::Float, &out)) && out) { return implementation::MapInt64BitToFloat::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, TensorKind::Double)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::Double, &out)) && out) { return implementation::MapInt64BitToDouble::Create(); } // Sequences - else if (IsOfVectorMapType(ort_value, TensorKind::String, TensorKind::Float)) { + if (SUCCEEDED(value->IsOfVectorMapType(TensorKind::String, TensorKind::Float, &out)) && out) { return implementation::SequenceMapStringFloat::Create(); - } else if (IsOfVectorMapType(ort_value, TensorKind::Int64, TensorKind::Float)) { + } + if (SUCCEEDED(value->IsOfVectorMapType(TensorKind::Int64, TensorKind::Float, &out)) && out) { return implementation::SequenceMapInt64BitFloat::Create(); } - auto utf8Name = WinML::Strings::UTF8FromHString(descriptor.Name()); + auto utf8_name = WinML::Strings::UTF8FromHString(descriptor.Name()); WINML_THROW_HR_IF_TRUE_MSG( E_UNEXPECTED, true, "The engine produced an unexpected evaluation output for unbound output variable %s.", - utf8Name.c_str()); + utf8_name.c_str()); return nullptr; } Windows::Foundation::IInspectable LearningModelBinding::CreateUnboundOutput( const std::string& name, - Ort::Value& ort_value) { + winrt::com_ptr value) { // Find valid binding port auto bindingPort = FindValidBinding( m_session.Model(), @@ -432,12 +358,12 @@ Windows::Foundation::IInspectable 
LearningModelBinding::CreateUnboundOutput( }; // Create empty feature value - auto featureValue = CreateUnboundOuputFeatureValue(ort_value, descriptor); + auto featureValue = CreateUnboundOuputFeatureValue(value, descriptor); // Update feature value auto spLotusValueProvider = featureValue.as(); WINML_THROW_IF_FAILED_MSG( - spLotusValueProvider->UpdateSourceResourceData(context, ort_value), + spLotusValueProvider->UpdateSourceResourceData(context, value.get()), "Failed to update bound object for model variable output %s", name.c_str()); @@ -454,33 +380,30 @@ Windows::Foundation::IInspectable LearningModelBinding::CreateUnboundOutput( std::unordered_map LearningModelBinding::UpdateProviders() { std::unordered_map outputs; - auto& outputNames = GetOutputNames(); - auto& outputMLValues = GetOutputs(); + auto& output_names = GetOutputNames(); + auto& output_values = GetOutputs(); WINML_THROW_HR_IF_FALSE_MSG( E_UNEXPECTED, - outputNames.size() == outputMLValues.size(), + output_names.size() == output_values.size(), "Evaluation produced unexpected output variables."); - for (unsigned i = 0; i < outputNames.size(); i++) { - auto utf8Name = outputNames[i]; - OrtValue* mlValue = outputMLValues[i]; + for (unsigned i = 0; i < output_names.size(); i++) { + auto utf8_name = output_names[i]; + auto value = output_values[i]; - if (m_providers.find(utf8Name) != std::end(m_providers)) { - auto& providerInfo = m_providers[utf8Name]; + if (m_providers.find(utf8_name) != std::end(m_providers)) { + auto& providerInfo = m_providers[utf8_name]; auto provider = providerInfo.Provider; auto context = providerInfo.Context; WINML_THROW_IF_FAILED_MSG( - provider->UpdateSourceResourceData(context, mlValue), + provider->UpdateSourceResourceData(context, value.get()), "Failed to update bound object for model variable output %s", - utf8Name.c_str()); + utf8_name.c_str()); - outputs[utf8Name] = providerInfo.CallerSpecifiedFeatureValue; + outputs[utf8_name] = 
providerInfo.CallerSpecifiedFeatureValue; } else { // unbound outputs - Ort::Value ort_value(mlValue); - outputs[utf8Name] = CreateUnboundOutput(utf8Name, ort_value); - // this was a weak ref, don't let it deref() - ort_value.release(); + outputs[utf8_name] = CreateUnboundOutput(utf8_name, value); } } @@ -501,31 +424,23 @@ STDMETHODIMP LearningModelBinding::Bind( IUnknown* value) { try { _winmlt::TelemetryEvent binding_event(_winmlt::EventCategory::kBinding); - BindingType bindingType; - std::string bindingName; - OrtValue* binding_value_ptr = nullptr; - OrtAllocator* ort_allocator = nullptr; + BindingType binding_type; + std::string binding_name; + winrt::com_ptr binding_value; + winrt::Windows::Foundation::IInspectable to; RETURN_IF_FAILED(value->QueryInterface( winrt::guid_of(), reinterpret_cast(winrt::put_abi(to)))); auto featureName = WinML::Strings::UTF8FromUnicode(name, cchName); - std::tie(bindingName, binding_value_ptr, bindingType, ort_allocator) = CreateBinding(featureName, to, nullptr); - Ort::Value ortValue = binding_value_ptr ? 
Ort::Value(binding_value_ptr) : Ort::Value(nullptr); - Ort::Allocator ortAllocator(adapter_.get(), ort_allocator); - switch (bindingType) { + std::tie(binding_name, binding_value, binding_type) = CreateBinding(featureName, to, nullptr); + switch (binding_type) { case BindingType::kInput: - WINML_THROW_IF_FAILED(BindInput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindInput(binding_name, binding_value)); break; case BindingType::kOutput: - WINML_THROW_IF_FAILED(BindOutput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindOutput(binding_name, binding_value)); break; default: FAIL_FAST(); @@ -544,43 +459,37 @@ static std::pair Contains(const std::vector& names, c } // This method releases control of memory of ml_value from caller of BindInput -HRESULT LearningModelBinding::BindInput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator) { - auto rc = Contains(input_names_, name); +HRESULT LearningModelBinding::BindInput(const std::string& name, winrt::com_ptr value) { + bool exists; + size_t index; + std::tie(exists, index) = Contains(input_names_, name); - auto add_or_replace = [this, &name](const bool exists, size_t index, Ort::Value&& value, Ort::Allocator&& ort_allocator) { - if (exists) { - inputs_[index] = std::move(value); - input_allocators_[index] = std::move(ort_allocator); - } else { - input_names_.push_back(name); - inputs_.push_back(std::move(value)); - input_allocators_.push_back(std::move(ort_allocator)); - } - }; - if (ml_value.IsTensor()) { - OrtValue* new_mlvalue; - WINML_THROW_IF_FAILED(m_session.as() - ->GetIInferenceSession() - ->CopyOneInputAcrossDevices(name.c_str(), ml_value, &new_mlvalue)); - add_or_replace(rc.first, rc.second, Ort::Value(new_mlvalue), std::move(ort_allocator)); + auto engine = m_session.as()->GetEngine(); + winrt::com_ptr device_value; + WINML_THROW_IF_FAILED(engine->CreateOneInputAcrossDevices(name.c_str(), 
value.get(), device_value.put())); // an input will always be copied on device mismatch + + if (exists) { + inputs_[index] = device_value; } else { - add_or_replace(rc.first, rc.second, Ort::Value(ml_value.release()), std::move(ort_allocator)); + input_names_.push_back(name); + inputs_.push_back(device_value); } + return S_OK; } -// This method releases control of memory of ml_value from caller of BindInput -HRESULT LearningModelBinding::BindOutput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator) { - auto rc = Contains(output_names_, name); - if (rc.first) { - outputs_[rc.second] = std::move(ml_value); - output_allocators_[rc.second] = std::move(ort_allocator); +HRESULT LearningModelBinding::BindOutput(const std::string& name, winrt::com_ptr value) { + bool exists; + size_t index; + std::tie(exists, index) = Contains(output_names_, name); + + if (exists) { + outputs_[index] = value; return S_OK; } output_names_.push_back(name); - outputs_.push_back(std::move(ml_value)); - output_allocators_.push_back(std::move(ort_allocator)); + outputs_.push_back(value); return S_OK; } @@ -588,13 +497,17 @@ const std::vector& LearningModelBinding::GetOutputNames() const { return output_names_; } -std::vector& LearningModelBinding::GetOutputs() { return outputs_; } - const std::vector& LearningModelBinding::GetInputNames() const { return input_names_; } -const std::vector& LearningModelBinding::GetInputs() const { return inputs_; } +std::vector>& LearningModelBinding::GetOutputs() { + return outputs_; +} + +const std::vector>& LearningModelBinding::GetInputs() const { + return inputs_; +} void LearningModelBinding::BindUnboundOutputs() { auto& bound_output_names = GetOutputNames(); @@ -634,7 +547,11 @@ void LearningModelBinding::BindUnboundOutputs() { // Add all unbound outputs to binding collection for (const auto& unbound_output : unbound_output_names) { - WINML_THROW_IF_FAILED(BindOutput(unbound_output, Ort::Value(nullptr), Ort::Allocator())); + 
auto engine = m_session.as()->GetEngine(); + + winrt::com_ptr value; + WINML_THROW_IF_FAILED(engine->CreateNullValue(value.put())); + WINML_THROW_IF_FAILED(BindOutput(unbound_output, value)); } } diff --git a/winml/lib/Api/LearningModelBinding.h b/winml/lib/Api/LearningModelBinding.h index 0d2efc2339c6b..4dd2734bc0710 100644 --- a/winml/lib/Api/LearningModelBinding.h +++ b/winml/lib/Api/LearningModelBinding.h @@ -22,11 +22,12 @@ struct LearningModelBinding : LearningModelBindingT; LearningModelBinding() = delete; - ~LearningModelBinding(); LearningModelBinding(Windows::AI::MachineLearning::LearningModelSession const& session); void Bind(hstring const& name, Windows::Foundation::IInspectable const& value); void Bind(hstring const& name, Windows::Foundation::IInspectable const& value, Windows::Foundation::Collections::IPropertySet const& properties); + STDMETHOD(Bind)(const wchar_t* name, UINT32 cchName, IUnknown* value); + void Clear(); Windows::Foundation::Collections::IIterator First(); Windows::Foundation::IInspectable Lookup(hstring const& key); @@ -36,7 +37,7 @@ struct LearningModelBinding : LearningModelBindingT& first, Windows::Foundation::Collections::IMapView& second); - std::tuple CreateBinding( + std::tuple, WinML::BindingType> CreateBinding( const std::string& name, const Windows::Foundation::IInspectable& value, Windows::Foundation::Collections::IPropertySet const& properties); @@ -45,42 +46,32 @@ struct LearningModelBinding : LearningModelBindingT& LearningModelBinding::GetOutputNames() const; - std::vector& LearningModelBinding::GetOutputs(); - const std::vector& LearningModelBinding::GetInputNames() const; - const std::vector& LearningModelBinding::GetInputs() const; - HRESULT BindOutput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator); + const std::vector& GetInputNames() const; + const std::vector& GetOutputNames() const; + + const std::vector>& GetInputs() const; + std::vector>& GetOutputs(); + + HRESULT 
BindOutput(const std::string& name, winrt::com_ptr value); void BindUnboundOutputs(); private: void CacheProvider(std::string name, ProviderInfo& spProvider); - Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, Ort::Value& ort_value); + Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, winrt::com_ptr value); ILearningModelFeatureValue CreateUnboundOuputFeatureValue( - const Ort::Value& ort_value, + const winrt::com_ptr value, ILearningModelFeatureDescriptor& descriptor); - bool IsOfTensorType(const Ort::Value& ort_value, TensorKind kind); - bool IsOfMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind); - bool IsOfVectorMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind); - HRESULT BindInput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator); + HRESULT BindInput(const std::string& name, winrt::com_ptr value); private: const Windows::AI::MachineLearning::LearningModelSession m_session; std::unordered_map m_providers; - com_ptr adapter_; std::vector input_names_; - std::vector inputs_; - std::vector input_allocators_; + std::vector> inputs_; std::vector output_names_; - std::vector outputs_; - std::vector output_allocators_; + std::vector> outputs_; }; } // namespace winrt::Windows::AI::MachineLearning::implementation diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index 8f77242b5b4e3..013dca5b863ca 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -28,41 +28,42 @@ static const GUID WINML_PIX_EVAL_CAPTURABLE_WORK_GUID = __uuidof(guid_details::W namespace winrt::Windows::AI::MachineLearning::implementation { LearningModelSession::LearningModelSession( - winml::LearningModel const& model) try : LearningModelSession(model, - make(LearningModelDeviceKind::Default)) {} + winml::LearningModel const& model) try : LearningModelSession(model, 
+ make(LearningModelDeviceKind::Default)) {} WINML_CATCH_ALL LearningModelSession::LearningModelSession( - winml::LearningModel const& model, - winml::LearningModelDevice const& deviceToRunOn) try : LearningModelSession(model, - deviceToRunOn, - nullptr) {} + winml::LearningModel const& model, + winml::LearningModelDevice const& deviceToRunOn) try : LearningModelSession(model, + deviceToRunOn, + nullptr) {} WINML_CATCH_ALL LearningModelSession::LearningModelSession( - winml::LearningModel const& model, - winml::LearningModelDevice const& deviceToRunOn, - winml::LearningModelSessionOptions const& learningModelSessionOptions) try : model_(model), - device_(deviceToRunOn), - session_options_(learningModelSessionOptions) { + winml::LearningModel const& model, + winml::LearningModelDevice const& deviceToRunOn, + winml::LearningModelSessionOptions const& learningModelSessionOptions) try : model_(model), + device_(deviceToRunOn), + session_options_(learningModelSessionOptions), + operator_registry_(nullptr, nullptr) { Initialize(); } WINML_CATCH_ALL -winmla::IModelProto* +WinML::IModel* LearningModelSession::GetOptimizedModel() { // Get the model proto auto should_close_model = - session_options_ != nullptr && - session_options_.CloseModelOnSessionCreation(); + session_options_ != nullptr && + session_options_.CloseModelOnSessionCreation(); return GetOptimizedModel(should_close_model); } -winmla::IModelProto* +WinML::IModel* LearningModelSession::GetOptimizedModel(bool should_close_model) { - com_ptr model_proto; + com_ptr model; { // Lock the model detach/copy since multiple threads can access concurrently @@ -70,77 +71,70 @@ LearningModelSession::GetOptimizedModel(bool should_close_model) { // Throw if the model has been disposed and is not capable of creating // new sessions. 
- auto model = model_.as(); - WINML_THROW_HR_IF_TRUE_MSG(E_INVALIDARG, model->IsDisposed(), + auto model_impl = model_.as(); + WINML_THROW_HR_IF_TRUE_MSG(E_INVALIDARG, model_impl->IsDisposed(), "The model has been disposed."); - model_proto.attach(should_close_model - ? model->DetachModelProto() - : model->CopyModelProto()); + model.attach(should_close_model + ? model_impl->DetachModel() + : model_impl->CloneModel()); } // Ensure that the model is runnable on the device - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - WINML_THROW_IF_FAILED(adapter->EnsureModelDeviceCompatibility(model_, model_proto.get(), device_.as()->GetD3DDeviceCache()->IsFloat16Supported())); - - return model_proto.detach(); + auto isFloat16Supported = device_.as()->GetD3DDeviceCache()->IsFloat16Supported(); + if (!isFloat16Supported) { + WINML_THROW_IF_FAILED(model->ModelEnsureNoFloat16()); + } + return model.detach(); } void LearningModelSession::Initialize() { // Begin recording session creation telemetry _winmlt::TelemetryEvent session_creation_event( - _winmlt::EventCategory::kSessionCreation); + _winmlt::EventCategory::kSessionCreation); // Get the optimized model proto from the learning model - com_ptr model_proto; - model_proto.attach(GetOptimizedModel()); + com_ptr model; + model.attach(GetOptimizedModel()); // Create the session builder auto device_impl = device_.as(); + auto model_impl = model_.as(); - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); + engine_factory_.copy_from(model_impl->GetEngineFactory()); - com_ptr session_builder; - WINML_THROW_IF_FAILED(adapter->CreateOrtSessionBuilder( - device_impl->GetD3DDevice(), - device_impl->GetDeviceQueue(), - session_builder.put())); + com_ptr engine_builder; + engine_factory_->CreateEngineBuilder(engine_builder.put()); - Ort::SessionOptions options(nullptr); - WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(options.put())); + if (device_impl->IsCpuDevice() == 
false) { + engine_builder->SetD3D12Resources(device_impl->GetD3DDevice(), device_impl->GetDeviceQueue()); + } // Make onnxruntime apply the batch size override, if any - if (session_options_ && session_options_.BatchSizeOverride() != 0) - { - Ort::ThrowOnError(Ort::GetApi().AddFreeDimensionOverride( - options, - onnx::DATA_BATCH, - session_options_.BatchSizeOverride())); + if (session_options_ && session_options_.BatchSizeOverride() != 0) { + engine_builder->SetBatchSizeOverride(session_options_.BatchSizeOverride()); } - com_ptr session; - WINML_THROW_IF_FAILED(session_builder->CreateSession( - options, session.put(), &cached_execution_provider_)); + com_ptr engine; + WINML_THROW_IF_FAILED(engine_builder->CreateEngine(engine.put())); // Register the custom operator registry - auto model = model_.as(); - operatorRegistry_.reset(model->GetOperatorRegistry()); - WINML_THROW_IF_FAILED(session->RegisterCustomRegistry(operatorRegistry_.get())); + operator_registry_ = MLOperatorRegistry(model_impl->GetOperatorRegistry(), [](auto registry) { registry->Release(); }); + WINML_THROW_IF_FAILED(engine->RegisterCustomRegistry(operator_registry_.get())); - // Register only the transformers not already in ORT - session->RegisterGraphTransformers(); + // Register transformers - this should probably not be exposed on IEngine, but an internal call as this configuration step is ort specific. 
+ engine->RegisterGraphTransformers(); // Load the model into the session - WINML_THROW_IF_FAILED(session->LoadModel(model_proto.get())); + WINML_THROW_IF_FAILED(engine->LoadModel(model.get())); + // the session owns the model_proto now, it used detach() - model_proto = nullptr; + model = nullptr; // Initialize the session - WINML_THROW_IF_FAILED(session_builder->Initialize(session.get(), cached_execution_provider_)); + WINML_THROW_IF_FAILED(engine->Initialize()); // Cache the constructed session - inference_session_ = session; + engine_ = engine; } wfc::IPropertySet @@ -165,8 +159,8 @@ LearningModelSession::Device() try { WINML_CATCH_ALL auto CreateBinding( - LearningModelSession& session, - wfc::IMap const features) { + LearningModelSession& session, + wfc::IMap const features) { auto binding = winrt::make(session); for (auto feature : features.GetView()) { @@ -177,8 +171,8 @@ auto CreateBinding( winml::LearningModelEvaluationResult LearningModelSession::EvaluateFeatures( - wfc::IMap const features, - hstring const correlation_id) try { + wfc::IMap const features, + hstring const correlation_id) try { auto binding = CreateBinding(*this, features); return Evaluate(binding, correlation_id); } @@ -186,65 +180,63 @@ WINML_CATCH_ALL wf::IAsyncOperation LearningModelSession::EvaluateFeaturesAsync( - wfc::IMap const features, - hstring const correlation_id) { + wfc::IMap const features, + hstring const correlation_id) { auto binding = CreateBinding(*this, features); return EvaluateAsync(binding, correlation_id); } -// copied from onnxruntime_cxx_inline.h -inline OrtStatus* OrtRun( - OrtSession * session, - const Ort::RunOptions& run_options, - const char* const* input_names, - const Ort::Value* input_values, - size_t input_count, - const char* const* output_names, - Ort::Value* output_values, - size_t output_count) { - static_assert(sizeof(Ort::Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); - auto 
ort_input_values = reinterpret_cast(const_cast(input_values)); - auto ort_output_values = reinterpret_cast(output_values); - return Ort::GetApi().Run(session, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values); -} - -uint64_t -LearningModelSession::Run( - winrt::com_ptr binding_impl) { +uint64_t LearningModelSession::Run(winrt::com_ptr binding_impl) { CheckClosed(); + auto device = device_.as(); CWinMLAutoLock lock(!device->IsCpuDevice() ? &evaluate_lock_ : nullptr); - // TODO : set the run_options - Ort::RunOptions run_options; + binding_impl->BindUnboundOutputs(); - std::vector inputNames_c; - for (int i=0; i < binding_impl->GetInputNames().size(); i++) - { - inputNames_c.push_back(binding_impl->GetInputNames()[i].c_str()); - } - std::vector outputNames_c; - for (int i = 0; i < binding_impl->GetOutputNames().size(); i++) { - outputNames_c.push_back(binding_impl->GetOutputNames()[i].c_str()); - } - OrtSession* session = nullptr; - - WINML_THROW_IF_FAILED(inference_session_->GetOrtSession(&session)); - // Invoke run on the ORT session. 
- Ort::ThrowOnError(OrtRun( - session, - run_options, - inputNames_c.data(), - binding_impl->GetInputs().data(), - binding_impl->GetInputs().size(), - outputNames_c.data(), - binding_impl->GetOutputs().data(), - binding_impl->GetOutputs().size())); + auto& input_names = binding_impl->GetInputNames(); + std::vector input_names_raw; + std::transform( + std::begin(input_names), + std::end(input_names), + std::back_inserter(input_names_raw), + [&](auto& name) { return name.c_str(); }); + + auto& inputs = binding_impl->GetInputs(); + std::vector inputs_raw; + std::transform( + std::begin(inputs), + std::end(inputs), + std::back_inserter(inputs_raw), + [&](auto& input) { return input.get(); }); + + auto& output_names = binding_impl->GetOutputNames(); + std::vector output_names_raw; + std::transform( + std::begin(output_names), + std::end(output_names), + std::back_inserter(output_names_raw), + [&](auto& name) { return name.c_str(); }); + + auto outputs = binding_impl->GetOutputs(); + std::vector outputs_raw; + std::transform( + std::begin(outputs), + std::end(outputs), + std::back_inserter(outputs_raw), + [&](auto& input) { return input.get(); }); + + engine_->Run(input_names_raw.data(), + inputs_raw.data(), + input_names_raw.size(), + output_names_raw.data(), + outputs_raw.data(), + output_names_raw.size()); if (!device->IsCpuDevice()) { // Flush the D3D12 work from the DML execution provider and queue a fence before we release the lock. // This allows us to wait without holding onto the lock in GetResults. 
- inference_session_->FlushContext(GetExecutionProvider()); + engine_->FlushContext(); return device->GetD3DDeviceCache()->QueueFenceToD3D12(); } @@ -254,9 +246,9 @@ LearningModelSession::Run( winml::LearningModelEvaluationResult LearningModelSession::GetResults( - winrt::com_ptr binding_impl, - hstring const& correlation_id, - uint64_t evaluation_complete_fence) { + winrt::com_ptr binding_impl, + hstring const& correlation_id, + uint64_t evaluation_complete_fence) { // First wait on the fence value for the expected frame. This is passed in so that // the fence value is added to the queue in a thread safe manor. auto device = device_.as(); @@ -271,10 +263,10 @@ LearningModelSession::GetResults( if (is_gpu_evaluation) { // For DML we aren't using the Sync function because we want to make fencing the // completed frame thread safe while not holding the lock while waiting for the gpu. - inference_session_->ReleaseCompletedReferences(GetExecutionProvider()); + engine_->ReleaseCompletedReferences(); } else { // For CPU call the standard Sync function - GetExecutionProvider()->Sync(); + engine_->Sync(); } // This isn't the best we are holding the lock while we wait for detensorize on the GPU. @@ -286,7 +278,7 @@ LearningModelSession::GetResults( // to avoid requiring the extra allocation during each evaluation. 
if (is_first_evaluate_) { if (is_gpu_evaluation) { - inference_session_->TrimUploadHeap(GetExecutionProvider()); + engine_->TrimUploadHeap(); } is_first_evaluate_ = false; } @@ -309,7 +301,7 @@ LearningModelSession::EvaluateAsync( _winmlt::TelemetryEvent kEvaluateModel_event(_winmlt::EventCategory::kEvaluation); auto device = device_.as(); - // Get the ORT binding collection + // Get the binding collection auto binding_impl = binding.as(); ApplyEvaluationProperties(); @@ -369,7 +361,7 @@ LearningModelSession::Evaluate( capture_interface->BeginCapturableWork(WINML_PIX_EVAL_CAPTURABLE_WORK_GUID); } - // Get the ORT binding collection + // Get the binding collection auto binding_impl = binding.as(); uint64_t evaluation_complete_fence = Run(binding_impl); @@ -383,16 +375,14 @@ LearningModelSession::Evaluate( WINML_CATCH_ALL void LearningModelSession::Close() { - inference_session_ = nullptr; + engine_ = nullptr; } void LearningModelSession::ApplyEvaluationProperties() try { if (evaluation_properties_) { auto is_debug_output_enabled = evaluation_properties_.HasKey(c_enable_debug_output); if (is_debug_output_enabled) { - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - adapter->EnableDebugOutput(); + engine_factory_->EnableDebugOutput(is_debug_output_enabled); } } } @@ -407,24 +397,19 @@ void LearningModelSession::ToggleProfiler() { WINML_PROVIDER_KEYWORD_LOTUS_PROFILING); if (is_provider_enabled) { - inference_session_->StartProfiling(); + engine_->StartProfiling(); } else { - inference_session_->EndProfiling(); + engine_->EndProfiling(); } } -onnxruntime::IExecutionProvider* -LearningModelSession::GetExecutionProvider() { - return cached_execution_provider_; -} - -winmla::IInferenceSession* -LearningModelSession::GetIInferenceSession() { - return inference_session_.get(); +WinML::IEngine* +LearningModelSession::GetEngine() { + return engine_.get(); } void LearningModelSession::CheckClosed() { - if (!inference_session_) { + if (!engine_) { 
WINML_THROW_HR(RO_E_CLOSED); } } diff --git a/winml/lib/Api/LearningModelSession.h b/winml/lib/Api/LearningModelSession.h index 8c0acf51171cc..bdb1dd2fb0d03 100644 --- a/winml/lib/Api/LearningModelSession.h +++ b/winml/lib/Api/LearningModelSession.h @@ -9,6 +9,7 @@ #include "MLOperatorAuthor.h" #include "WinML_Lock.h" #include "core/providers/winml/winml_provider_factory.h" +#include "iengine.h" namespace winrt::Windows::AI::MachineLearning::implementation { @@ -66,11 +67,9 @@ struct LearningModelSession : LearningModelSessionT { public: /* Non-ABI methods */ - onnxruntime::IExecutionProvider* - GetExecutionProvider(); - winmla::IInferenceSession* - GetIInferenceSession(); + WinML::IEngine* + GetEngine(); void CheckClosed(); @@ -79,10 +78,10 @@ struct LearningModelSession : LearningModelSessionT { void Initialize(); - winmla::IModelProto* + WinML::IModel* GetOptimizedModel(); - winmla::IModelProto* + WinML::IModel* GetOptimizedModel(bool should_close_model); uint64_t @@ -102,16 +101,11 @@ struct LearningModelSession : LearningModelSessionT { ToggleProfiler(); private: - com_ptr inference_session_; - struct IMLOperatorRegistryDeleter { - void operator()(IMLOperatorRegistry* p) { - p->Release(); - } - }; - std::unique_ptr operatorRegistry_; - - // reference to the active execution provider. 
weak - onnxruntime::IExecutionProvider* cached_execution_provider_ = nullptr; + com_ptr engine_factory_; + com_ptr engine_; + + using MLOperatorRegistry = std::unique_ptr; + MLOperatorRegistry operator_registry_; winml::LearningModel model_; winml::LearningModelDevice device_; diff --git a/winml/lib/Api/MapFeatureDescriptor.cpp b/winml/lib/Api/MapFeatureDescriptor.cpp index d30734c3be065..60f63c13f85d1 100644 --- a/winml/lib/Api/MapFeatureDescriptor.cpp +++ b/winml/lib/Api/MapFeatureDescriptor.cpp @@ -18,19 +18,6 @@ MapFeatureDescriptor::MapFeatureDescriptor( value_kind_(value_kind) { } -MapFeatureDescriptor::MapFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& KeyKind, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ValueDescriptor) : - name_(Name), - description_(Description), - is_required_(IsRequired), - key_kind_(KeyKind), - value_kind_(ValueDescriptor) { -} - winml::TensorKind MapFeatureDescriptor::KeyKind() try { return key_kind_; diff --git a/winml/lib/Api/MapFeatureDescriptor.h b/winml/lib/Api/MapFeatureDescriptor.h index 1b752b2eb3b63..3641585dd7d87 100644 --- a/winml/lib/Api/MapFeatureDescriptor.h +++ b/winml/lib/Api/MapFeatureDescriptor.h @@ -17,14 +17,7 @@ struct MapFeatureDescriptor : MapFeatureDescriptorT< bool is_required, winml::TensorKind keyKind, winml::ILearningModelFeatureDescriptor valueKind); - - MapFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& KeyKind, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ValueDescriptor); - + // IMapDescriptor winml::TensorKind KeyKind(); @@ -62,10 +55,4 @@ struct MapFeatureDescriptor : MapFeatureDescriptorT< winml::TensorKind key_kind_; winml::ILearningModelFeatureDescriptor value_kind_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - -namespace 
winrt::Windows::AI::MachineLearning::factory_implementation { - struct MapFeatureDescriptor : MapFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/SequenceFeatureDescriptor.cpp b/winml/lib/Api/SequenceFeatureDescriptor.cpp index 725a66bae253b..0cc1248cc88eb 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.cpp +++ b/winml/lib/Api/SequenceFeatureDescriptor.cpp @@ -16,16 +16,6 @@ SequenceFeatureDescriptor::SequenceFeatureDescriptor( description_(WinML::Strings::HStringFromUTF8(description)), is_required_(is_required), element_descriptor_(descriptor) {} -SequenceFeatureDescriptor::SequenceFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ElementDescriptor) : - name_(Name), - description_(Description), - is_required_(IsRequired), - element_descriptor_(ElementDescriptor) { -} winml::ILearningModelFeatureDescriptor diff --git a/winml/lib/Api/SequenceFeatureDescriptor.h b/winml/lib/Api/SequenceFeatureDescriptor.h index 04e5d392ae261..c45a06ccaba38 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.h +++ b/winml/lib/Api/SequenceFeatureDescriptor.h @@ -15,11 +15,6 @@ struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< const char* description, bool is_required, winml::ILearningModelFeatureDescriptor element_descriptor); - SequenceFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ElementDescriptor); winml::ILearningModelFeatureDescriptor ElementDescriptor(); @@ -53,10 +48,4 @@ struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< bool is_required_; winml::ILearningModelFeatureDescriptor element_descriptor_; }; -} // namespace 
winrt::Windows::AI::MachineLearning::implementation - -namespace winrt::Windows::AI::MachineLearning::factory_implementation { - struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/TensorFeatureDescriptor.cpp b/winml/lib/Api/TensorFeatureDescriptor.cpp index e4517f7b3870b..3cf7cc6a36fd9 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.cpp +++ b/winml/lib/Api/TensorFeatureDescriptor.cpp @@ -11,9 +11,9 @@ namespace winrt::Windows::AI::MachineLearning::implementation { TensorFeatureDescriptor::TensorFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, bool has_unsupported_image_metadata) : name_(WinML::Strings::HStringFromUTF8(name)), description_(WinML::Strings::HStringFromUTF8(description)), tensor_kind_(tensor_kind), @@ -22,20 +22,6 @@ TensorFeatureDescriptor::TensorFeatureDescriptor( has_unsupported_image_metadata_(has_unsupported_image_metadata) { } -TensorFeatureDescriptor::TensorFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - winml::TensorKind const& TensorKind, - array_view Shape, - bool HasUnsupportedImageMetadata) : name_(Name), - description_(Description), - tensor_kind_(TensorKind), - shape_(Shape.begin(), Shape.end()), - is_required_(IsRequired), - has_unsupported_image_metadata_(HasUnsupportedImageMetadata) { -} - winml::TensorKind TensorFeatureDescriptor::TensorKind() try { return tensor_kind_; @@ -75,11 +61,6 @@ bool TensorFeatureDescriptor::IsRequired() try { } WINML_CATCH_ALL -bool TensorFeatureDescriptor::HasUnsupportedImageMetadata() try { - return has_unsupported_image_metadata_; -} -WINML_CATCH_ALL - bool TensorFeatureDescriptor::IsUnsupportedMetaData() try { return 
has_unsupported_image_metadata_; } diff --git a/winml/lib/Api/TensorFeatureDescriptor.h b/winml/lib/Api/TensorFeatureDescriptor.h index 975b359c1f13b..5e54978c5847a 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.h +++ b/winml/lib/Api/TensorFeatureDescriptor.h @@ -13,23 +13,17 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< TensorFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, bool has_unsuppored_image_metadata); - TensorFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - TensorKind const& TensorKind, - array_view Shape, - bool HasUnsupportedImageMetadata); - // ITensorDescriptor - winml::TensorKind TensorKind(); - wfc::IVectorView Shape(); - bool HasUnsupportedImageMetadata(); + winml::TensorKind + TensorKind(); + + wfc::IVectorView + Shape(); // IFeatureDescriptor winrt::hstring @@ -65,10 +59,4 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< bool is_required_; bool has_unsupported_image_metadata_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - -namespace winrt::Windows::AI::MachineLearning::factory_implementation { - struct TensorFeatureDescriptor : TensorFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/impl/MapBase.h b/winml/lib/Api/impl/MapBase.h index b3490bbe1a03b..d59366140f69e 100644 --- a/winml/lib/Api/impl/MapBase.h +++ b/winml/lib/Api/impl/MapBase.h @@ -40,53 +40,9 @@ struct MapBase : winrt::implements< std::is_same::value, "Map values must be int64_t, double, float, or winrt::hstring!"); - template - struct ValidLotusType { using Type = T; }; - template <> - struct ValidLotusType { using Type = std::string; }; - - using LotusKey = typename ValidLotusType::Type; - using 
LotusValue = typename ValidLotusType::Type; - using LotusMap = std::pair, std::vector>; using ABIMap = ::winrt::Windows::Foundation::Collections::IMap; using ABIMapView = ::winrt::Windows::Foundation::Collections::IMapView; - template - static typename ValidLotusType::Type ConvertToValidLotusType(TRawType raw) { - return raw; - } - - template <> - static typename ValidLotusType::Type ConvertToValidLotusType(winrt::hstring raw) { - return WinML::Strings::UTF8FromHString(raw); - } - - template - static std::vector ConvertToABIType(Ort::Value& ort_value) { - // make sure this is an array of these types - auto shape = ort_value.GetTensorTypeAndShapeInfo().GetShape(); - // there needs to be only one dimension - THROW_HR_IF(E_INVALIDARG, shape.size() != 1); - auto lotus_value = ort_value.GetTensorMutableData::Type>(); - // now go through all the entries - std::vector out; - for (auto i = 0; i < shape[0]; i++) { - out.push_back(lotus_value[i]); - } - // retun the vector - return out; - } - - template <> - static std::vector ConvertToABIType(Ort::Value& ort_value) { - auto strings = ort_value.GetStrings(); - std::vector out; - for (auto i = 0; i < strings.size(); ++i) { - out.push_back(WinML::Strings::HStringFromUTF8(strings[i].c_str())); - } - return out; - } - MapBase(ABIMap const& data) : data_(data) {} static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create() { @@ -129,51 +85,16 @@ struct MapBase : winrt::implements< return S_OK; } - void ConvertToLotusMap(const ABIMap& map) { - std::vector keys; - std::vector values; - for (const auto& pair : map) { - auto key = ConvertToValidLotusType(pair.Key()); - auto value = ConvertToValidLotusType(pair.Value()); - keys.push_back(key); - values.push_back(value); - } - lotus_data_ = std::make_unique(std::make_pair(keys, values)); - } - - template - static onnxruntime::MLDataType GetLotusType(winmla::IWinMLAdapter* adapter) { - return adapter->GetMapType(TensorKindFrom::Type, TensorKindFrom::Type); - } + 
STDMETHOD(GetValue) + (WinML::BindingContext& context, IValue** out) { + auto session = context.session.as(); + auto engine = session->GetEngine(); - template - static Ort::Value CreateOrtMap(TLotusKey* keys, TLotusValue* values, size_t len) { - // now create OrtValue wrappers over the buffers - auto cpu_memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::vector shape = {static_cast(len)}; - auto keys_ort_value = Ort::Value::CreateTensor(cpu_memory, keys, len, shape.data(), shape.size()); - auto values_ort_value = Ort::Value::CreateTensor(cpu_memory, values, len, shape.data(), shape.size()); - // make the map - return Ort::Value::CreateMap(keys_ort_value, values_ort_value); - } - - STDMETHOD(GetOrtValue) - (WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator) { - ORT_UNUSED_PARAMETER(ort_allocator); - ORT_UNUSED_PARAMETER(context); - // TODO: Tensorized data should be cached so multiple bindings work more efficiently - - // TODO : we need to handle inputs. 
for now only handle outputs and don't pre allocate anything - if (context.type == WinML::BindingType::kOutput) { - *ort_value = nullptr; - return S_OK; + if (context.type == WinML::BindingType::kInput) { + RETURN_IF_FAILED(engine->CreateMapValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, out)); + } else { + RETURN_IF_FAILED(engine->CreateNullValue(out)); } - - // handle inputs, create and store a copy of the map - ConvertToLotusMap(data_); - - // and make the map - *ort_value = CreateOrtMap(lotus_data_->first.data(), lotus_data_->second.data(), lotus_data_->first.size()).release(); return S_OK; } @@ -185,51 +106,23 @@ struct MapBase : winrt::implements< } STDMETHOD(UpdateSourceResourceData) - (BindingContext& context, OrtValue* ort_value) { - ORT_UNUSED_PARAMETER(context); + (BindingContext& context, IValue* value) { data_.Clear(); - - Ort::AllocatorWithDefaultOptions allocator; - - // get the keys - OrtValue* ptr = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetValue(ort_value, 0, allocator, &ptr)); - Ort::Value keys{ptr}; - // get the values - ptr = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetValue(ort_value, 1, allocator, &ptr)); - Ort::Value values{ptr}; - - auto keys_vector = ConvertToABIType(keys); - auto values_vector = ConvertToABIType(values); - - auto len = keys.GetCount(); - for (auto i = 0; i < len; ++i) { - data_.Insert(keys_vector[i], values_vector[i]); - } - return S_OK; - - // TODO: code this - //const LotusMap& map = *static_cast(pResource); - //for (const auto& pair : map) { - // auto key = ConvertToABIType(pair.first); - // auto value = ConvertToABIType(pair.second); - // data_.Insert(key, value); - //} - + auto session = context.session.as(); + auto engine = session->GetEngine(); + RETURN_IF_FAILED(engine->FillFromMapValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, value)); return S_OK; } STDMETHOD(AbiRepresentation) ( - 
winrt::Windows::Foundation::IInspectable& abiRepresentation) { + winrt::Windows::Foundation::IInspectable& abiRepresentation) { data_.as(abiRepresentation); return S_OK; } private: ABIMap data_; - std::unique_ptr lotus_data_; }; } // namespace Windows::AI::MachineLearning diff --git a/winml/lib/Api/impl/SequenceBase.h b/winml/lib/Api/impl/SequenceBase.h index 560ade3faeae6..d84008ec69ed9 100644 --- a/winml/lib/Api/impl/SequenceBase.h +++ b/winml/lib/Api/impl/SequenceBase.h @@ -22,35 +22,28 @@ struct SequenceBase : public winrt::implements< winml::ILearningModelFeatureValue, WinML::ISequenceFeatureValue, WinML::ILotusValueProviderPrivate> { + using ABISequence = wfc::IIterable; using AbiMapStringToFloat = wfc::IMap; using AbiMapInt64BitToFloat = wfc::IMap; - template - struct ValidLotusType { using Type = T; }; - template <> - struct ValidLotusType { - //using Type = std::map; - using TKey = std::string; - using TValue = float; - using Type = std::pair, std::vector>; - using ABIKey = winrt::hstring; - using ABIValue = TValue; + template struct SequenceAbiTypeInfo { + static constexpr winml::TensorKind Key = winml::TensorKind::Undefined; + static constexpr winml::TensorKind Value = winml::TensorKind::Undefined; + }; + template <> struct SequenceAbiTypeInfo { + static constexpr winml::TensorKind Key = winml::TensorKind::String; + static constexpr winml::TensorKind Value = winml::TensorKind::Float; }; template <> - struct ValidLotusType { - //using Type = std::map; - using TKey = int64_t; - using TValue = float; - using Type = std::pair, std::vector>; - using ABIKey = TKey; - using ABIValue = TValue; + struct SequenceAbiTypeInfo { + static constexpr winml::TensorKind Key = winml::TensorKind::Int64; + static constexpr winml::TensorKind Value = winml::TensorKind::Float; }; template void GetElementDescriptor(winml::ILearningModelFeatureDescriptor* result) { - *result = TensorFeatureDescriptorFrom::CreateAnonymous( - std::vector{1, 1, 1, 1}); + static_assert(false, "Only 
sequences of of map and map are supported.") } template <> @@ -87,9 +80,6 @@ struct SequenceBase : public winrt::implements< value_descriptor /* value kind */); } - using LotusSequence = std::vector::Type>; - using ABISequence = wfc::IIterable; - SequenceBase(const ABISequence& data) : data_(data) {} static winml::ILearningModelFeatureValue @@ -120,114 +110,22 @@ struct SequenceBase : public winrt::implements< return S_OK; } - template - static - typename ValidLotusType::Type - ConvertToValidLotusType( - TRawType raw) { - return raw; - } - - template <> - static - typename ValidLotusType::Type - ConvertToValidLotusType( - winrt::hstring raw) { - return WinML::Strings::UTF8FromHString(raw); - } - - template <> - static - typename ValidLotusType::Type - ConvertToValidLotusType( - AbiMapStringToFloat raw) { - std::vector::TKey> keys; - std::vector::TValue> values; - for (auto pair : raw) { - auto key = WinML::Strings::UTF8FromHString(pair.Key()); - keys.push_back(key); - values.push_back(pair.Value()); - } - return std::make_pair(keys, values); - } - - template <> - static - typename ValidLotusType::Type - ConvertToValidLotusType( - AbiMapInt64BitToFloat raw) { - std::vector::TKey> keys; - std::vector::TValue> values; - for (const auto& pair : raw) { - keys.push_back(pair.Key()); - values.push_back(pair.Value()); - } - return std::make_pair(keys, values); - } - - void - ConvertToLotusSequence( - const ABISequence& sequence) { - LotusSequence lotus_sequence; - - std::transform( - begin(sequence), - end(sequence), - std::back_inserter(lotus_sequence), - [](const auto& value) { - return ConvertToValidLotusType(value); - }); - - lotus_data_ = std::make_unique(lotus_sequence); - } - - template - static Ort::Value CreateOrtMap(TLotusKey* keys, TLotusValue* values, size_t len) { - // now create OrtValue wrappers over the buffers - auto cpu_memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::vector shape = {static_cast(len)}; - auto 
keys_ort_value = Ort::Value::CreateTensor(cpu_memory, keys, len, shape.data(), shape.size()); - auto values_ort_value = Ort::Value::CreateTensor(cpu_memory, values, len, shape.data(), shape.size()); - // make the map - return Ort::Value::CreateMap(keys_ort_value, values_ort_value); - } - - STDMETHOD(GetOrtValue)( + STDMETHOD(GetValue)( WinML::BindingContext& context, - OrtValue** ort_value, - OrtAllocator** ort_allocator) { - ORT_UNUSED_PARAMETER(ort_allocator); - // TODO: Tensorized data should be cached so multiple bindings work more efficiently - - // TODO : we need to handle inputs. for now only handle outputs and don't pre allocate anything - if (context.type == WinML::BindingType::kOutput) { - *ort_value = nullptr; - return S_OK; + IValue** out) { + auto session = context.session.as(); + auto engine = session->GetEngine(); + + if (context.type == WinML::BindingType::kInput) { + // In opset 10, all ops that use sequences are seq. + // In opset 11, we will need to support seq> as well. 
+ RETURN_IF_FAILED(engine->CreateSequenceOfMapsValue( + reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), + SequenceAbiTypeInfo::Key, SequenceAbiTypeInfo::Value, out)); + } else { + RETURN_IF_FAILED(engine->CreateNullValue(out)); } - - // handle inputs, create and store a copy of the sequence - ConvertToLotusSequence(data_); - - // now create OrtValue wrappers over the buffers - std::vector sequence_values; - for (auto it = lotus_data_->begin(); it != lotus_data_->end(); ++it) { - // make a ort value for this map - auto map = *it; - sequence_values.emplace_back(CreateOrtMap(map.first.data(), map.second.data(), map.first.size())); - } - *ort_value = Ort::Value::CreateSequence(sequence_values).release(); return S_OK; - - /* winrt::com_ptr adapter; - RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - auto lotus_type = adapter->GetVectorMapType( - TensorKindFrom::TKey>::Type, - TensorKindFrom::TValue>::Type); - - winrt::com_ptr ml_value_out; - adapter->CreateOrtValue(lotus_data_.get(), lotus_type, ml_value_out.put()); - - *ml_value = ml_value_out.detach();*/ } STDMETHOD(IsPlaceholder) @@ -238,61 +136,16 @@ struct SequenceBase : public winrt::implements< return S_OK; } - template - static std::vector ConvertToABIType(Ort::Value& ort_value) { - // make sure this is an array of these types - auto shape = ort_value.GetTensorTypeAndShapeInfo().GetShape(); - // there needs to be only one dimension - THROW_HR_IF(E_INVALIDARG, shape.size() != 1); - auto lotus_value = ort_value.GetTensorMutableData::Type>(); - // now go through all the entries - std::vector out; - for (auto i = 0; i < shape[0]; i++) { - out.push_back(lotus_value[i]); - } - // return the vector - return out; - } - - template <> - static std::vector ConvertToABIType(Ort::Value& ort_value) { - auto strings = ort_value.GetStrings(); - std::vector out; - for (auto i = 0; i < strings.size(); ++i) { - out.push_back(WinML::Strings::HStringFromUTF8(strings[i].c_str())); - } - return out; - } - 
STDMETHOD(UpdateSourceResourceData)( BindingContext& context, - OrtValue* ort_value) { - ORT_UNUSED_PARAMETER(context); + IValue* out) { auto writable_vector = data_.as>(); writable_vector.Clear(); - Ort::AllocatorWithDefaultOptions allocator; - size_t len; - Ort::ThrowOnError(Ort::GetApi().GetValueCount(ort_value, &len)); - for (auto i = 0; i < len; ++i) { - OrtValue* out = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetValue(ort_value, i, allocator, &out)); - Ort::Value map{out}; - auto keys = map.GetValue(0, allocator); - auto values = map.GetValue(1, allocator); + auto session = context.session.as(); + auto engine = session->GetEngine(); + RETURN_IF_FAILED(engine->FillSequenceOfMapsValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), SequenceAbiTypeInfo::Key, SequenceAbiTypeInfo::Value, out)); - auto keys_vector = ConvertToABIType::ABIKey>(keys); - auto values_vector = ConvertToABIType::ABIValue>(values); - - std::map::ABIKey, typename ValidLotusType::ABIValue> std_map; - for (auto j = 0; j < keys_vector.size(); ++j) { - std_map[keys_vector[j]] = values_vector[j]; - } - auto abi_map = winrt::single_threaded_map::ABIKey, typename ValidLotusType::ABIValue>( - std::move(std_map)); - - writable_vector.Append(abi_map); - } return S_OK; } @@ -304,7 +157,6 @@ struct SequenceBase : public winrt::implements< private: ABISequence data_; - std::unique_ptr lotus_data_; }; -} // namespace Windows::AI::MachineLearning +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api/impl/Tensor.h b/winml/lib/Api/impl/Tensor.h index 69503cb42b63e..2e0e2b16ee34d 100644 --- a/winml/lib/Api/impl/Tensor.h +++ b/winml/lib/Api/impl/Tensor.h @@ -5,14 +5,6 @@ #include "TensorBuffer.h" -// we further specialize these base types for a couple of extra tensor element types -namespace Ort { -template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; }; -template <> -struct 
TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; -} - // // the Tensor class is the actual object for CPU memory buffers. // TensorBase contains one of these to represent the raw memory @@ -27,13 +19,11 @@ class Tensor { TensorBufferPtr m_buffer; std::vector shape_; - winrt::com_ptr adapter_; public: Tensor() = delete; Tensor( - winmla::IWinMLAdapter* adapter, std::vector const& shape, winrt::Windows::Storage::Streams::IBuffer buffer) : shape_(shape), m_buffer( @@ -45,11 +35,9 @@ class Tensor { static_cast(1), std::multiplies())), buffer)) { - adapter_.copy_from(adapter); } Tensor( - winmla::IWinMLAdapter* adapter, std::vector const& shape) : shape_(shape), m_buffer( TensorBuffer::Create( @@ -59,11 +47,9 @@ class Tensor { std::end(shape), static_cast(1), std::multiplies())))) { - adapter_.copy_from(adapter); } Tensor( - winmla::IWinMLAdapter* adapter, std::vector const&& shape) : shape_(std::move(shape)), m_buffer( TensorBuffer::Create( @@ -73,31 +59,18 @@ class Tensor { std::end(shape), static_cast(1), std::multiplies())))) { - adapter_.copy_from(adapter); } auto size() const { return m_buffer->Size(); } - auto buffer() { - return m_buffer->Buffer(); + auto size_in_bytes() const { + return m_buffer->SizeInBytes(); } - Ort::Value GetValue() { - // this is cpu memory - // TODO: what is the difference between the device allocator and the arena allocator? 
- Ort::MemoryInfo cpu_memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - // create the OrtValue as a tensor letting ort know that we own the data buffer - auto value = Ort::Value::CreateTensor( - cpu_memory, - buffer().second, - m_buffer->SizeInBytes(), - shape_.data(), - shape_.size()); -// Ort::TypeToTensorType::type); - return value; + auto buffer() { + return m_buffer->Buffer(); } void set(uint32_t size, const T* pData) { @@ -111,5 +84,9 @@ class Tensor { const std::vector& shape() const { return shape_; } + + auto get_tensor_buffer() { + return m_buffer; + } }; } // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api/impl/TensorBase.h b/winml/lib/Api/impl/TensorBase.h index c706ecbd4f914..c79197fe4b9e9 100644 --- a/winml/lib/Api/impl/TensorBase.h +++ b/winml/lib/Api/impl/TensorBase.h @@ -15,6 +15,7 @@ #include "core/session/onnxruntime_c_api.h" namespace Windows::AI::MachineLearning { + // TensorBase // // This is the base class for all data based Tensor types. 
It exposes array and IVectorView @@ -69,87 +70,78 @@ struct TensorBase : TBase { /// 3) use provided backing gpu memory /// a) TensorBase(std::vector const& shape, ID3D12Resource* pResource) TensorBase() : m_resources(std::make_shared>()) { - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); } TensorBase(winrt::Windows::Foundation::Collections::IIterable const& shape) : shape_(begin(shape), end(shape)), m_resources(std::make_shared>()) { - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); - GetCpuResource() = std::make_shared>(adapter_.get(), shape_); + GetCpuResource() = std::make_shared>(shape_); } TensorBase(std::vector const& shape) : shape_(shape), m_resources(std::make_shared>()) { - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); - GetCpuResource() = std::make_shared>(adapter_.get(), shape_); + GetCpuResource() = std::make_shared>(shape_); } - TensorBase(std::vector const& shape, ID3D12Resource* pResource, UINT64 resource_width) : shape_(shape), - m_resources(std::make_shared>()) { - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); + TensorBase(std::vector const& shape, ID3D12Resource* resource) : shape_(shape), + m_resources(std::make_shared>()) { // This Api is not supported for TensorString WINML_THROW_HR_IF_TRUE_MSG( E_ILLEGAL_METHOD_CALL, (std::is_same::value), "TensorString objects cannot be created from a ID3D12Resource!"); - GetGpuResource() = std::make_shared(pResource, resource_width); + GetGpuResource().copy_from(resource); } - Ort::Value CreateGPUMLValue(std::shared_ptr& resource, BindingContext& context) { + HRESULT CreateGPUMLValue(ID3D12Resource* resource, BindingContext& context, IValue** out) { THROW_HR_IF_NULL(E_INVALIDARG, resource); - THROW_HR_IF_NULL(E_UNEXPECTED, resource->ExecutionProviderAllocatedResource); - - Ort::MemoryInfo dml_memory(nullptr); - auto session_impl = context.session.as(); - auto provider = session_impl->GetExecutionProvider(); - 
WINML_THROW_IF_FAILED(adapter_->GetProviderMemoryInfo(provider, dml_memory.put())); - auto spSession = context.session.as(); - auto spDevice = spSession->Device().as(); + auto session = context.session.as(); + auto device = session->Device().as(); WINML_THROW_HR_IF_TRUE_MSG(WINML_ERR_INVALID_BINDING, - spDevice->IsCpuDevice(), + device->IsCpuDevice(), "Cannot create GPU tensor on CPU device"); - // create the OrtValue as a tensor letting ort know that we own the data buffer - auto value = Ort::Value::CreateTensor( - dml_memory, - resource->ExecutionProviderAllocatedResource, - resource->resource_width_, - shape_.data(), - shape_.size(), - Ort::TypeToTensorType::type); - return value; + auto engine = session->GetEngine(); + RETURN_IF_FAILED(engine->CreateTensorValueFromExternalD3DResource(resource, shape_.data(), shape_.size(), TensorKind(), out)); + return S_OK; } - Ort::Value CPUTensorize(WinML::BindingContext& context) { + HRESULT CPUTensorize(WinML::BindingContext& context, IValue** out) { + auto session = context.session.as(); + auto engine = session->GetEngine(); + if (GetCpuResource() != nullptr) { - return GetCpuResource()->GetValue(); + return CreateTensorValueFromExternalBuffer(engine, out); } // If there is no matching cpu resource, then fallback to a gpu resource if (GetGpuResource() != nullptr) { - return CreateGPUMLValue(GetGpuResource(), context); + return CreateGPUMLValue(GetGpuResource().get(), context, out); } WINML_THROW_HR(WINML_ERR_INVALID_BINDING); } - Ort::Value GPUTensorize(WinML::BindingContext& context) { + HRESULT GPUTensorize(WinML::BindingContext& context, IValue** out) { if (GetGpuResource() != nullptr) { - return CreateGPUMLValue(GetGpuResource(), context); + return CreateGPUMLValue(GetGpuResource().get(), context, out); } + // Get engine + auto session = context.session.as(); + auto engine = session->GetEngine(); + // If there is no matching gpu resource, then fallback to a cpu resource if (GetCpuResource() != nullptr) { - return 
GetCpuResource()->GetValue(); + return CreateTensorValueFromExternalBuffer(engine, out); } if (TensorKind() == winrt::Windows::AI::MachineLearning::TensorKind::String) { // Lazily allocate the cpu TensorString resource // TensorStrings are CPU only, and so a gpu resource cannot be allocated for them. - GetCpuResource() = std::make_shared>(adapter_.get(), shape_); - return GetCpuResource()->GetValue(); + GetCpuResource() = std::make_shared>(shape_); + return CreateTensorValueFromExternalBuffer(engine, out); } else { // Try to allocate the backing memory for the caller auto bufferSize = std::accumulate(std::begin(shape_), std::end(shape_), static_cast(1), std::multiplies()); @@ -178,21 +170,21 @@ struct TensorBase : TBase { D3D12_TEXTURE_LAYOUT_ROW_MAJOR, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; - auto spSession = context.session.as(); - auto spDevice = spSession->Device().as(); + auto device = session->Device().as(); - winrt::com_ptr pGPUResource = nullptr; - spDevice->GetD3DDevice()->CreateCommittedResource( + winrt::com_ptr gpu_resource = nullptr; + device->GetD3DDevice()->CreateCommittedResource( &heapProperties, D3D12_HEAP_FLAG_NONE, &resourceDesc, D3D12_RESOURCE_STATE_COMMON, nullptr, __uuidof(ID3D12Resource), - pGPUResource.put_void()); + gpu_resource.put_void()); + + GetGpuResource() = gpu_resource; - GetGpuResource() = std::make_shared(pGPUResource.get(), resourceDesc.Width); - return CreateGPUMLValue(GetGpuResource(), context); + return CreateGPUMLValue(GetGpuResource().get(), context, out); } } @@ -207,9 +199,8 @@ struct TensorBase : TBase { } // ILotusValueProviderPrivate::GetOrtValue - STDMETHOD(GetOrtValue) - (WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator) { - ORT_UNUSED_PARAMETER(ort_allocator); + STDMETHOD(GetValue) + (WinML::BindingContext& context, IValue** out) { RETURN_HR_IF_NULL_MSG( WINML_ERR_INVALID_BINDING, m_resources, @@ -219,10 +210,11 @@ struct TensorBase : TBase { auto spSession = 
context.session.as(); auto spDevice = spSession->Device().as(); + if (spDevice->IsCpuDevice()) { - *ort_value = CPUTensorize(context).release(); + RETURN_IF_FAILED(CPUTensorize(context, out)); } else { - *ort_value = GPUTensorize(context).release(); + RETURN_IF_FAILED(GPUTensorize(context, out)); } return S_OK; @@ -240,47 +232,88 @@ struct TensorBase : TBase { return size; } + template + void SetBufferFromValueResourceBuffer(uint32_t size, void* data) { + // This adds compile time checks that ensure that the API can only be called when + // the conditions of ASSERT_TEMPLATE_PARAMETERS_EXACT() are met. + ASSERT_TEMPLATE_PARAMETERS(); + + GetCpuResource()->set(size, reinterpret_cast(data)); + } + + template <> + void SetBufferFromValueResourceBuffer(uint32_t size, void* data) { + // Ensure that this call is being called with the correct template parameters + ASSERT_TEMPLATE_PARAMETERS(); + + GetCpuResource()->get_tensor_buffer()->Set(size, reinterpret_cast(data)); + } + + template + HRESULT CreateTensorValueFromExternalBuffer(WinML::IEngine* engine, IValue** value) { + // This adds compile time checks that ensure that the API can only be called when + // the conditions of ASSERT_TEMPLATE_PARAMETERS_EXACT() are met. 
+ ASSERT_TEMPLATE_PARAMETERS(); + + RETURN_IF_FAILED_MSG(engine->CreateTensorValueFromExternalBuffer( + GetCpuResource()->buffer().second, GetCpuResource()->size_in_bytes(), GetCpuResource()->shape().data(), + GetCpuResource()->shape().size(), TensorKind(), value), + "Failed to prepare buffer for copy back from device resource."); + return S_OK; + } + + template <> + HRESULT CreateTensorValueFromExternalBuffer(WinML::IEngine* engine, IValue** value) { + // Ensure that this call is being called with the correct template parameters + ASSERT_TEMPLATE_PARAMETERS(); + + std::vector raw_values; + auto string_array = GetCpuResource()->buffer().second; + std::transform( + string_array, + string_array + GetCpuResource()->size_in_bytes(), + std::back_inserter(raw_values), + [&](auto& str) { return str.c_str(); }); + + RETURN_IF_FAILED_MSG(engine->CreateStringTensorValueFromDataWithCopy( + raw_values.data(), raw_values.size(), GetCpuResource()->shape().data(), + GetCpuResource()->shape().size(), value), + "Failed to prepare buffer for copy back from device resource."); + return S_OK; + } // ILotusValueProviderPrivate::UpdateSourceResourceData STDMETHOD(UpdateSourceResourceData) - (BindingContext& context, OrtValue* ort_value) { + (BindingContext& context, IValue* value) { RETURN_HR_IF_NULL_MSG( E_ILLEGAL_METHOD_CALL, m_resources, "The tensor has been closed and its resources have been detached during evaluation!"); - // get the mutable raw data buffer - void* pResource = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(ort_value, &pResource)); + WinML::Resource updated_resource; + RETURN_IF_FAILED(value->GetResource(updated_resource)); // get the shape - Ort::TensorTypeAndShapeInfo type_and_shape(nullptr); - Ort::ThrowOnError(Ort::GetApi().GetTensorTypeAndShape(ort_value, type_and_shape.put())); - shape_ = type_and_shape.GetShape(); + RETURN_IF_FAILED_MSG(value->GetTensorShape(shape_), "Failed to get the tensor shape from resource!"); // make sure we always 
have a CPU resource if (GetCpuResource() == nullptr) { - GetCpuResource() = std::make_shared>(adapter_.get(), shape_); + GetCpuResource() = std::make_shared>(shape_); } - // get the memory info for the ort value - Ort::MemoryInfo memory_info(nullptr); - RETURN_IF_FAILED(adapter_->GetValueMemoryInfo(ort_value, memory_info.put())); - - // is it from the CPU provider? - if (!strcmp(memory_info.Name(), onnxruntime::CPU) || - memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUOutput || - memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUInput) { + bool is_cpu; + if (SUCCEEDED(value->IsCpu(&is_cpu)) && is_cpu) { // Get the data pointer and size - T* pData; - uint32_t pSize; - std::tie(pSize, pData) = GetCpuResource()->buffer(); + T* data; + uint32_t size; + std::tie(size, data) = GetCpuResource()->buffer(); - if (pResource != reinterpret_cast(pData)) { + if (updated_resource.get() != reinterpret_cast(data)) { // Only copy the data if the source and destination are not the same! // The engine provided buffer will not match the tensor buffer when // the tensor is created as a placeholder output, or as an unbound output. - GetCpuResource()->set(static_cast(ShapeSize(shape_)), reinterpret_cast(pResource)); + auto shape_size = static_cast(ShapeSize(shape_)); + SetBufferFromValueResourceBuffer(shape_size, updated_resource.get()); } } else { // If we got a gpu resource, we should move the data to the cpu so accessors can retrieve the data. @@ -288,8 +321,12 @@ struct TensorBase : TBase { // resources for tensors. Therefore we are certain that the returned dxresource is the same as the one we passed in // and was updated in place. 
auto spSession = context.session.as(); - auto cpuValue = GetCpuResource()->GetValue(); - RETURN_IF_FAILED(adapter_->CopyTensor(spSession->GetExecutionProvider(), ort_value, cpuValue)); + auto engine = spSession->GetEngine(); + + winrt::com_ptr dest; + RETURN_IF_FAILED_MSG(CreateTensorValueFromExternalBuffer(engine, dest.put()), + "Failed to prepare buffer for copy back from device resource."); + RETURN_IF_FAILED(engine->CopyValueAcrossDevices(value, dest.get())); } return S_OK; @@ -377,7 +414,7 @@ struct TensorBase : TBase { typename TBase::class_type tensorValue = winrt::make(); auto tensorValueImpl = tensorValue.as(); tensorValueImpl->shape_ = vecShape; - tensorValueImpl->GetCpuResource() = std::make_shared>(tensorValueImpl->adapter_.get(), vecShape, buffer); + tensorValueImpl->GetCpuResource() = std::make_shared>(vecShape, buffer); return tensorValue; } WINML_CATCH_ALL @@ -410,7 +447,7 @@ struct TensorBase : TBase { THROW_HR_IF(E_INVALIDARG, desc.Width < width); // make the underlying winrt object - typename TBase::class_type tensorValue = winrt::make(shapeVector, value, desc.Width); + typename TBase::class_type tensorValue = winrt::make(shapeVector, value); // return it (the caller owns the ref) *result = tensorValue.as().detach(); @@ -496,7 +533,7 @@ struct TensorBase : TBase { // This Api is not supported for TensorString RETURN_HR_IF_MSG( ERROR_INVALID_FUNCTION, - (std::is_same::value), + (std::is_same_v), "TensorString objects cannot return byte buffers!"); RETURN_HR_IF_NULL_MSG( @@ -518,7 +555,7 @@ struct TensorBase : TBase { m_resources, "The tensor has been closed and its resources have been detached!"); - GetGpuResource()->DXResource.copy_to(ppResource); + GetGpuResource().copy_to(ppResource); return S_OK; } WINML_CATCH_ALL_COM @@ -551,12 +588,12 @@ struct TensorBase : TBase { // Specialized version to convert float16 to float template <> - winrt::Windows::Foundation::Collections::IVectorView GetAsVectorView() try { + 
winrt::Windows::Foundation::Collections::IVectorView GetAsVectorView() try { // Ensure that this call is being called with the correct template parameters - ASSERT_TEMPLATE_PARAMETERS(); + ASSERT_TEMPLATE_PARAMETERS(); uint32_t size; - onnxruntime::MLFloat16* pBuffer; + WinML::Half* pBuffer; // Get the data pointer and size std::tie(size, pBuffer) = GetCpuResource()->buffer(); @@ -567,7 +604,7 @@ struct TensorBase : TBase { floatValue.data(), sizeof(float) /* output stride */, reinterpret_cast(pBuffer), - sizeof(DirectX::PackedVector::HALF) /* input stride */, + sizeof(WinML::Half) /* input stride */, size); // Create IVectorView from copied data. @@ -684,12 +721,12 @@ struct TensorBase : TBase { // Specialized version to convert floats to float16 template <> - void SetBufferFromArray(winrt::array_view data) { + void SetBufferFromArray(winrt::array_view data) { // Ensure that this call is being called with the correct template parameters - ASSERT_TEMPLATE_PARAMETERS(); + ASSERT_TEMPLATE_PARAMETERS(); uint32_t size; - onnxruntime::MLFloat16* pBuffer; + WinML::Half* pBuffer; // Get the data pointer and size std::tie(size, pBuffer) = GetCpuResource()->buffer(); @@ -697,7 +734,7 @@ struct TensorBase : TBase { THROW_HR_IF(E_UNEXPECTED, data.size() != size); DirectX::PackedVector::XMConvertFloatToHalfStream( reinterpret_cast(pBuffer), - sizeof(DirectX::PackedVector::HALF) /* output stride */, + sizeof(WinML::Half) /* output stride */, data.data(), sizeof(float) /* input stride */, data.size()); @@ -760,13 +797,13 @@ struct TensorBase : TBase { // Specialized version to convert floats to float16 template <> - void SetBufferFromIterable( + void SetBufferFromIterable( winrt::Windows::Foundation::Collections::IIterable const& data) { // Ensure that this call is being called with the correct template parameters - ASSERT_TEMPLATE_PARAMETERS(); + ASSERT_TEMPLATE_PARAMETERS(); uint32_t size; - onnxruntime::MLFloat16* pBuffer; + WinML::Half* pBuffer; // Get the data pointer and 
size std::tie(size, pBuffer) = GetCpuResource()->buffer(); @@ -826,7 +863,7 @@ struct TensorBase : TBase { return m_resources->CpuResource; } - std::shared_ptr& GetGpuResource() { + winrt::com_ptr& GetGpuResource() { WINML_THROW_HR_IF_NULL_MSG( E_ILLEGAL_METHOD_CALL, m_resources, @@ -840,7 +877,6 @@ struct TensorBase : TBase { std::shared_ptr> m_resources; std::vector>> m_outstandingReferences; bool m_isClosed = false; - winrt::com_ptr adapter_; }; } // namespace Windows::AI::MachineLearning diff --git a/winml/lib/Api/impl/TensorBuffer.h b/winml/lib/Api/impl/TensorBuffer.h index 079175fca2a27..d43b61d7cb25a 100644 --- a/winml/lib/Api/impl/TensorBuffer.h +++ b/winml/lib/Api/impl/TensorBuffer.h @@ -133,9 +133,7 @@ class TensorBuffer { return std::make_pair(gsl::narrow_cast(m_buffer.size()), m_buffer.data()); } - // The Set APIs should generally be avoided implemented in the TensorBuffer. - // Callers should generally use the Buffer API and copy directly into it. - auto Set(uint32_t size, const std::string* pData) { + auto Set(uint32_t size, std::string_view* data) { WINML_THROW_HR_IF_FALSE_MSG( E_INVALIDARG, size <= m_buffer.size(), @@ -143,24 +141,8 @@ class TensorBuffer { static_cast(size), static_cast(m_buffer.size())); - std::copy(pData, pData + size, m_buffer.begin()); - } - - auto Set(std::vector&& other) { - auto tensorSize = m_buffer.size(); - - WINML_THROW_HR_IF_FALSE_MSG( - E_INVALIDARG, - other.size() <= tensorSize, - "Vector argument other has size (%d) which is greater than tensor size(%d)", - static_cast(other.size()), - static_cast(tensorSize)); - - if (tensorSize != other.size()) { - other.resize(tensorSize); - } - - m_buffer = std::move(other); + // Copy + std::copy(data, data + size, m_buffer.begin()); } }; } // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api/impl/TensorKindFrom.h b/winml/lib/Api/impl/TensorKindFrom.h index 0a3c8a6a7218e..e9662727f53eb 100644 --- a/winml/lib/Api/impl/TensorKindFrom.h +++ 
b/winml/lib/Api/impl/TensorKindFrom.h @@ -4,6 +4,13 @@ #pragma once namespace Windows::AI::MachineLearning { + +// We need to define our own type for Half since DirectX::PackedVector::Half resolves to uint16_t per its typedef declaration. +// Templates require an actual type name to resolve correctly. +struct Half { + DirectX::PackedVector::HALF value; +}; + template struct TensorKindFrom {}; template <> @@ -60,12 +67,7 @@ struct TensorKindFrom { static const winml::TensorKind Type = wi template <> struct TensorKindFrom { static const winml::TensorKind Type = winml::TensorKind::String; }; template <> -struct TensorKindFrom { static const winml::TensorKind Type = winml::TensorKind::Float16; }; - -template -struct ONNXTensorElementDataTypeFrom {}; - - +struct TensorKindFrom { static const winml::TensorKind Type = winml::TensorKind::Float16; }; template struct TensorFeatureDescriptorFrom { @@ -75,9 +77,9 @@ struct TensorFeatureDescriptorFrom { return winrt::make( nullptr /* set to null as values are name-less */, nullptr /* set to null as values are description-less */, - false /* set to false as values dont have required annotations */, TensorKindFrom::Type, shape, + false /* set to false as values dont have required annotations */, false /* set to false as this is not a tensor of unsupported metadata */); } }; diff --git a/winml/lib/Api/impl/TensorMemoryBufferReference.h b/winml/lib/Api/impl/TensorMemoryBufferReference.h index bc6234c4e8741..f5df6c47c68c0 100644 --- a/winml/lib/Api/impl/TensorMemoryBufferReference.h +++ b/winml/lib/Api/impl/TensorMemoryBufferReference.h @@ -9,24 +9,6 @@ #include namespace Windows::AI::MachineLearning { -struct DMLResource { - DMLResource(ID3D12Resource* pResource, UINT64 resource_width) { - DXResource.copy_from(pResource); - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); - ExecutionProviderAllocatedResource = adapter_->CreateGPUAllocationFromD3DResource(pResource); - resource_width_ = resource_width; - } - - 
~DMLResource() { - adapter_->FreeGPUAllocation(ExecutionProviderAllocatedResource); - } - - winrt::com_ptr DXResource; - UINT64 resource_width_; - void* ExecutionProviderAllocatedResource = nullptr; - winrt::com_ptr adapter_; -}; - template struct TensorResources { // ITensorNative::GetBuffer @@ -36,40 +18,28 @@ struct TensorResources { RETURN_HR_IF_NULL(E_POINTER, value); RETURN_HR_IF_NULL(E_POINTER, capacity); - *value = nullptr; - *capacity = 0; - - // This Api is not supported for TensorString - auto isTensorString = std::is_same::value; - RETURN_HR_IF(ERROR_INVALID_FUNCTION, isTensorString); + RETURN_HR_IF_MSG( + ERROR_INVALID_FUNCTION, + (std::is_same_v), + "TensorString objects cannot return byte buffers!"); try { + *value = nullptr; + *capacity = 0; + // Lazily allocate the cpu resource on call to GetBuffer if (CpuResource == nullptr) { - winrt::com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - CpuResource = std::make_shared>(adapter.get(), shape); + CpuResource = std::make_shared>(shape); } - if constexpr (std::is_same_v) { - std::string* pData; - uint32_t pSize; - std::tie(pSize, pData) = CpuResource->buffer(); - - // Set out parameters - *capacity = static_cast(pSize * sizeof(T)); - *value = (BYTE*)pData; - } else { - // Get the data pointer and size - T* pData; - uint32_t pSize; - std::tie(pSize, pData) = CpuResource->buffer(); - - // Set out parameters - *capacity = static_cast(pSize * sizeof(T)); - *value = (BYTE*)pData; - } + // Get the data pointer and size + T* data; + uint32_t size; + std::tie(size, data) = CpuResource->buffer(); + // Set out parameters + *capacity = static_cast(size * sizeof(T)); + *value = (BYTE*)data; return S_OK; } WINML_CATCH_ALL_COM @@ -77,7 +47,7 @@ struct TensorResources { // Theses are access directly by TensorMemoryBufferReference and TensorBase std::shared_ptr> CpuResource; - std::shared_ptr GpuResource; + winrt::com_ptr GpuResource; }; // This class holds onto the lifetime of 
TensorResources so that they can be kept alive by TensorBase AND its active MBRs. diff --git a/winml/lib/Api/inc/ILotusValueProviderPrivate.h b/winml/lib/Api/inc/ILotusValueProviderPrivate.h index 5ae5adc902a67..3bfc6a2a79961 100644 --- a/winml/lib/Api/inc/ILotusValueProviderPrivate.h +++ b/winml/lib/Api/inc/ILotusValueProviderPrivate.h @@ -3,7 +3,7 @@ #pragma once -#include "WinMLAdapter.h" +#include "iengine.h" // ILotusValueProviderPrivate exposes a private Lotus interface to the engine so that it can retrieve tensor // resources stored in winrt structures. @@ -24,9 +24,9 @@ struct BindingContext { }; struct __declspec(uuid("27e2f437-0112-4693-849e-e04323a620fb")) __declspec(novtable) ILotusValueProviderPrivate : IUnknown { - virtual HRESULT __stdcall GetOrtValue(BindingContext& binding_context, OrtValue** ort_value, OrtAllocator** ort_allocator) = 0; + virtual HRESULT __stdcall GetValue(BindingContext& binding_context, WinML::IValue** out) = 0; virtual HRESULT __stdcall IsPlaceholder(bool* is_placeholder) = 0; - virtual HRESULT __stdcall UpdateSourceResourceData(BindingContext& binding_context, OrtValue* ort_value) = 0; + virtual HRESULT __stdcall UpdateSourceResourceData(BindingContext& binding_context, WinML::IValue* value) = 0; virtual HRESULT __stdcall AbiRepresentation(winrt::Windows::Foundation::IInspectable& abi_representation) = 0; }; diff --git a/winml/lib/Common/inc/PheonixSingleton.h b/winml/lib/Common/inc/PheonixSingleton.h index c3ab8edd821cc..0ab0f21f4cad5 100644 --- a/winml/lib/Common/inc/PheonixSingleton.h +++ b/winml/lib/Common/inc/PheonixSingleton.h @@ -3,8 +3,8 @@ #pragma once -template -std::shared_ptr PheonixSingleton() { +template +std::shared_ptr PheonixSingleton(TArgs&&... 
args) { static std::weak_ptr instance_; static std::mutex lock_; @@ -13,7 +13,7 @@ std::shared_ptr PheonixSingleton() { return instance; } - auto instance = std::make_shared(); + auto instance = std::make_shared(std::forward(args)...); instance_ = instance; return instance; } \ No newline at end of file diff --git a/winml/lib/Common/inc/iengine.h b/winml/lib/Common/inc/iengine.h new file mode 100644 index 0000000000000..f9c9dd503dc40 --- /dev/null +++ b/winml/lib/Common/inc/iengine.h @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning { + +MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") +IModelInfo : IUnknown { + STDMETHOD(GetAuthor) + (const char** out, size_t* len) PURE; + + STDMETHOD(GetName) + (const char** out, size_t* len) PURE; + + STDMETHOD(GetDomain) + (const char** out, size_t* len) PURE; + + + STDMETHOD(GetDescription) + (const char** out, size_t* len) PURE; + + STDMETHOD(GetVersion) + (int64_t * out) PURE; + + STDMETHOD(GetModelMetadata) + (ABI::Windows::Foundation::Collections::IMapView * *metadata) PURE; + + STDMETHOD(GetInputFeatures) + (ABI::Windows::Foundation::Collections::IVectorView * *features) PURE; + + STDMETHOD(GetOutputFeatures) + (ABI::Windows::Foundation::Collections::IVectorView * *features) PURE; +}; + +MIDL_INTERFACE("1b198b76-5c44-480d-837c-8433ca6eaf99") +IModel : IUnknown { + STDMETHOD(GetModelInfo) + (IModelInfo * *info) PURE; + + STDMETHOD(ModelEnsureNoFloat16) + () PURE; + + STDMETHOD(CloneModel) + (IModel * *copy) PURE; +}; + +using Resource = std::unique_ptr>; +MIDL_INTERFACE("31f39226-cfe8-4758-af38-3d01b2a33ee1") +IValue : IUnknown { + STDMETHOD(IsEmpty) + (bool* out) PURE; + + STDMETHOD(IsCpu) + (bool* out) PURE; + + STDMETHOD(GetResource) + (WinML::Resource & resource) PURE; + + STDMETHOD(IsTensor) + (bool* out) PURE; + + STDMETHOD(IsOfTensorType) + (winml::TensorKind kind, bool* out) PURE; + + 
STDMETHOD(GetTensorShape) + (std::vector & shape_vector) PURE; + + STDMETHOD(IsOfMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE; + + STDMETHOD(IsOfVectorMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE; +}; + +MIDL_INTERFACE("30c99886-38d2-41cb-a615-203fe7d7daac") +IEngine : IUnknown { + STDMETHOD(LoadModel) + (_In_ IModel*) PURE; + + STDMETHOD(Initialize) + () PURE; + + STDMETHOD(RegisterGraphTransformers) + () PURE; + + STDMETHOD(RegisterCustomRegistry) + (IMLOperatorRegistry * registry) PURE; + + STDMETHOD(EndProfiling) + () PURE; + + STDMETHOD(StartProfiling) + () PURE; + + STDMETHOD(FlushContext) + () PURE; + + STDMETHOD(TrimUploadHeap) + () PURE; + + STDMETHOD(ReleaseCompletedReferences) + () PURE; + + STDMETHOD(Sync) + () PURE; + + STDMETHOD(CreateTensorValue) + (const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE; + + STDMETHOD(CreateTensorValueFromExternalD3DResource) + (ID3D12Resource * resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE; + + STDMETHOD(CreateTensorValueFromExternalBuffer) + (void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE; + + STDMETHOD(CreateStringTensorValueFromDataWithCopy) + (const char* const* data, size_t num_elements, const int64_t* shape, size_t count, _Out_ IValue** out) PURE; + + STDMETHOD(CreateNullValue) + (_Out_ IValue * *out) PURE; + + STDMETHOD(CreateMapValue) + (IInspectable * map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue * *out) PURE; + + STDMETHOD(CreateSequenceOfMapsValue) + (IInspectable * sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue * *out) PURE; + + STDMETHOD(CreateOneInputAcrossDevices) + (const char* name, IValue* src, IValue** dest) PURE; + + STDMETHOD(CopyValueAcrossDevices) + (IValue * src, IValue * dest) PURE; + + 
STDMETHOD(Run) + (const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) PURE; + + STDMETHOD(FillFromMapValue) + (IInspectable * map, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue * value) PURE; + + STDMETHOD(FillSequenceOfMapsValue) + (IInspectable * sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue * value) PURE; +}; + +MIDL_INTERFACE("0452ef15-b66b-47ca-9eff-aedac571764e") +IEngineBuilder : IUnknown { + STDMETHOD(SetD3D12Resources) + (ID3D12Device * device, ID3D12CommandQueue * queue) PURE; + + STDMETHOD(GetD3D12Device) + (ID3D12Device * *device) PURE; + + STDMETHOD(GetID3D12CommandQueue) + (ID3D12CommandQueue * *queue) PURE; + + STDMETHOD(SetBatchSizeOverride) + (uint32_t batch_size_override) PURE; + + STDMETHOD(CreateEngine) + (IEngine * *out) PURE; +}; + +MIDL_INTERFACE("5eddd25a-70ad-46ef-a445-78fbaf792c2f") +IEngineFactory : IUnknown { + STDMETHOD(CreateModel) + (_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) PURE; + + STDMETHOD(CreateModel) + (_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) PURE; + + STDMETHOD(CreateEngineBuilder) + (IEngineBuilder * *engine_builder) PURE; + + STDMETHOD(EnableDebugOutput) + (bool is_enabled) PURE; + + STDMETHOD(CreateCustomRegistry) + (_Out_ IMLOperatorRegistry * *registry) PURE; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Common/inc/onnx.h b/winml/lib/Common/inc/onnx.h index d537ba450b9d8..0db64211dc8e0 100644 --- a/winml/lib/Common/inc/onnx.h +++ b/winml/lib/Common/inc/onnx.h @@ -13,9 +13,6 @@ // Restore ERROR define #define ERROR 0 -// the C++ ort api -#include "core/session/onnxruntime_cxx_api.h" - #ifdef USE_DML #include #endif USE_DML diff --git a/winml/test/api/LearningModelBindingAPITest.cpp b/winml/test/api/LearningModelBindingAPITest.cpp index d67013dfc1db4..2f322987f9be7 100644 --- 
a/winml/test/api/LearningModelBindingAPITest.cpp +++ b/winml/test/api/LearningModelBindingAPITest.cpp @@ -154,6 +154,8 @@ static void DictionaryVectorizerMapString() WINML_EXPECT_TRUE(first.Current().Key() == mapInputName); WINML_EXPECT_TRUE(first.Current().Value() == mapInputInspectable); WINML_EXPECT_TRUE(binding.Lookup(mapInputName) == mapInputInspectable); + + modelSession.Evaluate(binding, L""); } static void RunZipMapInt64(