diff --git a/BUILD.md b/BUILD.md
index 2cff5ddfe4ff2..febf8ddb73bff 100644
--- a/BUILD.md
+++ b/BUILD.md
@@ -84,6 +84,7 @@ For other system requirements and other dependencies, please see [this section](
|**Build Shared Library**|--build_shared_lib||
|**Build Python wheel**|--build_wheel||
|**Build C# and C packages**|--build_csharp||
+|**Build WindowsML**|--use_winml<br/>--use_dml<br/>--build_shared_lib|WindowsML depends on DirectML and the OnnxRuntime shared library.|
|**Build Java package**|--build_java|Creates an onnxruntime4j.jar in the build directory, implies `--build_shared_lib`|
diff --git a/cgmanifest.json b/cgmanifest.json
index 707d663688d3f..f327e07a08297 100644
--- a/cgmanifest.json
+++ b/cgmanifest.json
@@ -390,7 +390,7 @@
"type": "git"
}
},
- {
+ {
"component": {
"git": {
"commitHash": "e8c599bca6c56c44b6730ad93f6abbc9ecd60fc1",
@@ -399,8 +399,8 @@
"type": "git"
}
},
- {
- "component":{
+ {
+ "component":{
"type": "other",
"Other": {
"Name": "Go",
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index f8002c004e64a..9b80c1c433321 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -90,6 +90,7 @@ option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir"
option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF)
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output data to standard output when executing the model." OFF)
option(onnxruntime_USE_DML "Build with DirectML support" OFF)
+option(onnxruntime_USE_WINML "Build with WinML support" OFF)
option(onnxruntime_USE_ACL "Build with ACL support" OFF)
option(onnxruntime_ENABLE_INSTRUMENT "Enable Instrument with Event Tracing for Windows (ETW)" OFF)
option(onnxruntime_USE_TELEMETRY "Build with Telemetry" OFF)
@@ -210,10 +211,14 @@ if (MSVC)
SET (CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Gw /GL")
SET (CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /Gw /GL")
endif()
- check_cxx_compiler_flag(-Qspectre HAS_QSPECTRE)
- if (HAS_QSPECTRE)
- SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Qspectre")
- SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Qspectre")
+ # The WinML build toolchain builds ARM/ARM64, and the internal toolchain does not have folders for Spectre mitigation libs.
+ # WinML performs Spectre mitigation differently.
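+ # Defining onnxruntime_DISABLE_QSPECTRE_CHECK (any value, e.g. -Donnxruntime_DISABLE_QSPECTRE_CHECK=ON on the CMake command line) skips this check.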
+ if (NOT DEFINED onnxruntime_DISABLE_QSPECTRE_CHECK)
+ check_cxx_compiler_flag(-Qspectre HAS_QSPECTRE)
+ if (HAS_QSPECTRE)
+ SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Qspectre")
+ SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Qspectre")
+ endif()
endif()
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DYNAMICBASE")
check_cxx_compiler_flag(-guard:cf HAS_GUARD_CF)
@@ -547,9 +552,12 @@ if (WIN32)
# set linker flags to minimize the binary size.
if (MSVC)
- foreach(type EXE SHARED)
- set(CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO} /OPT:REF,ICF,LBR")
- set(CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO} /INCREMENTAL:NO")
+ foreach(type EXE STATIC SHARED)
+ if (NOT type MATCHES STATIC)
+ set(CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO} /OPT:REF,ICF,LBR")
+ set(CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO} /INCREMENTAL:NO")
+ #TODO: the "/LTCG" switch should be controlled by onnxruntime_ENABLE_LTO
+ endif()
if (onnxruntime_ENABLE_LTO AND NOT onnxruntime_USE_CUDA)
set(CMAKE_${type}_LINKER_FLAGS_RELEASE "${CMAKE_${type}_LINKER_FLAGS_RELEASE} /LTCG")
set(CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_${type}_LINKER_FLAGS_RELWITHDEBINFO} /LTCG")
@@ -791,7 +799,7 @@ foreach(target_name onnxruntime_common onnxruntime_graph onnxruntime_framework o
endforeach()
foreach(provider_name ${ONNXRUNTIME_PROVIDER_NAMES})
- if(NOT provider_name STREQUAL "cpu")
+ if(NOT provider_name STREQUAL "cpu" AND NOT provider_name STREQUAL "winml")
if (MSVC)
target_compile_options(onnxruntime_providers_${provider_name} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
target_compile_options(onnxruntime_providers_${provider_name} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /sdl>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/sdl>")
@@ -817,6 +825,18 @@ endif()
+if (onnxruntime_USE_WINML)
+ # WinML uses and depends on the shared lib. Note: you can build WinML without DML, in which case you
+ # get a CPU-only WinML.
+ if (NOT onnxruntime_BUILD_SHARED_LIB)
+ message(
+ FATAL_ERROR
+ "Option onnxruntime_USE_WINML can only be used when onnxruntime_BUILD_SHARED_LIB is also enabled")
+ endif()
+ include(wil.cmake)
+ include(winml.cmake)
+endif() # if(onnxruntime_USE_WINML)
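+# A WinML build is typically configured (per the BUILD.md table above) with
+# --use_winml --use_dml --build_shared_lib, which map to onnxruntime_USE_WINML,
+# onnxruntime_USE_DML, and onnxruntime_BUILD_SHARED_LIB.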
+
#The following files may use the 'onnxruntime_libs' and 'onnxruntime_EXTERNAL_LIBRARIES' vars
if (onnxruntime_BUILD_SHARED_LIB)
diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake
index 99e4fa0404859..421b862a1a908 100644
--- a/cmake/external/dml.cmake
+++ b/cmake/external/dml.cmake
@@ -19,21 +19,19 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
- set(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/packages)
+ get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
+ set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.0.0.1)
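+ # DirectML.0.0.1 follows NuGet's <id>.<version> package folder layout; the version comes from packages.config.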
# Restore nuget packages, which will pull down the DirectML redist package
add_custom_command(
- OUTPUT restore_packages.stamp
+ OUTPUT ${DML_PACKAGE_DIR}/bin/x64/DirectML.lib ${DML_PACKAGE_DIR}/bin/x86/DirectML.lib
DEPENDS ${PACKAGES_CONFIG} ${NUGET_CONFIG}
COMMAND ${CMAKE_CURRENT_BINARY_DIR}/nuget/src/nuget restore ${PACKAGES_CONFIG} -PackagesDirectory ${PACKAGES_DIR} -ConfigFile ${NUGET_CONFIG}
- COMMAND ${CMAKE_COMMAND} -E touch restore_packages.stamp
VERBATIM)
- add_custom_target(RESTORE_PACKAGES ALL DEPENDS restore_packages.stamp)
+ include_directories(BEFORE "${DML_PACKAGE_DIR}/include")
+ add_custom_target(RESTORE_PACKAGES ALL DEPENDS ${DML_PACKAGE_DIR}/bin/x64/DirectML.lib ${DML_PACKAGE_DIR}/bin/x86/DirectML.lib)
add_dependencies(RESTORE_PACKAGES nuget)
-
- list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES RESTORE_PACKAGES)
else()
include_directories(${dml_INCLUDE_DIR})
- link_directories(${dml_LIB_DIR})
endif()
diff --git a/cmake/onnx/CMakeLists.txt b/cmake/onnx/CMakeLists.txt
index ca90a7faf7ab4..79177911da1f2 100644
--- a/cmake/onnx/CMakeLists.txt
+++ b/cmake/onnx/CMakeLists.txt
@@ -8,7 +8,7 @@ target_include_directories(onnx_proto PUBLIC $<TARGET_PROPERTY:protobuf::libprotobuf,INTERFACE_INCLUDE_DIRECTORIES>)
onnxruntime_protobuf_generate(APPEND_PATH IMPORT_DIRS ${ONNXRUNTIME_ROOT}/core/protobuf TARGET onnx_proto)
if (WIN32)
- target_compile_options(onnx_proto PRIVATE "/wd4146" "/wd4125" "/wd4456" "/wd4267")
+ target_compile_options(onnx_proto PRIVATE "/wd4146" "/wd4125" "/wd4456" "/wd4267" "/wd4309")
else()
if(HAS_UNUSED_VARIABLE)
target_compile_options(onnx_proto PRIVATE "-Wno-unused-variable")
@@ -53,6 +53,7 @@ if (WIN32)
/wd4100 # 'param' : unreferenced formal parameter
/wd4244 # 'argument' conversion from 'google::protobuf::int64' to 'int', possible loss of data
/EHsc # exception handling - C++ may throw, extern "C" will not
+ /wd4996 # 'argument' Using double parameter version instead of single parameter version of SetTotalBytesLimit(). The second parameter is ignored.
)
set(onnx_static_library_flags
-IGNORE:4221 # LNK4221: This object file does not define any previously undefined public symbols, so it will not be used by any link operation that consumes this library
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index dfb63dcf32238..6f0e9b0f177d3 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -70,6 +70,7 @@ target_link_libraries(onnxruntime PRIVATE
${PROVIDERS_NUPHAR}
${PROVIDERS_DML}
${PROVIDERS_ACL}
+ ${onnxruntime_winml}
onnxruntime_optimizer
onnxruntime_providers
onnxruntime_util
diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake
index 63a8283429dce..52aa2f3989d9d 100644
--- a/cmake/onnxruntime_common.cmake
+++ b/cmake/onnxruntime_common.cmake
@@ -44,6 +44,22 @@ else()
endif()
endif()
+if(CMAKE_GENERATOR_PLATFORM)
+ # Multi-platform generator
+ set(onnxruntime_target_platform ${CMAKE_GENERATOR_PLATFORM})
+else()
+ set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR})
+endif()
+if(onnxruntime_target_platform STREQUAL "ARM64")
+ set(onnxruntime_target_platform "ARM64")
+elseif(onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM")
+ set(onnxruntime_target_platform "ARM")
+elseif(onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64")
+ set(onnxruntime_target_platform "x64")
+elseif(onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686")
+ set(onnxruntime_target_platform "x86")
+endif()
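+# After this normalization onnxruntime_target_platform is one of ARM, ARM64, x86 or x64
+# (other values pass through unchanged); mlas and the DML/WinML builds key off this value.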
+
file(GLOB onnxruntime_common_src CONFIGURE_DEPENDS
${onnxruntime_common_src_patterns}
)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index e7b4213bc6cbb..38ec173dfc0b1 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -19,7 +19,7 @@ set(mlas_common_srcs
)
if(MSVC)
- if(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM64")
+ if(onnxruntime_target_platform STREQUAL "ARM64")
set(asm_filename ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/SgemmKernelNeon.asm)
set(pre_filename ${CMAKE_CURRENT_BINARY_DIR}/SgemmKernelNeon.i)
set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/SgemmKernelNeon.obj)
@@ -38,11 +38,11 @@ if(MSVC)
armasm64.exe ${ARMASM_FLAGS} ${pre_filename} ${obj_filename}
)
set(mlas_platform_srcs ${obj_filename})
- elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM")
+ elseif(onnxruntime_target_platform STREQUAL "ARM")
set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm/sgemmc.cpp
)
- elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "x64" OR CMAKE_GENERATOR MATCHES "Win64")
+ elseif(onnxruntime_target_platform STREQUAL "x64")
enable_language(ASM_MASM)
set(mlas_platform_srcs
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 6c42e78b5b94d..16bca1a9266c2 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -69,6 +69,10 @@ if(onnxruntime_USE_DML)
set(PROVIDERS_DML onnxruntime_providers_dml)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES dml)
endif()
+if(onnxruntime_USE_WINML)
+ set(PROVIDERS_WINML onnxruntime_providers_winml)
+ list(APPEND ONNXRUNTIME_PROVIDER_NAMES winml)
+endif()
if(onnxruntime_USE_ACL)
set(PROVIDERS_ACL onnxruntime_providers_acl)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES acl)
@@ -215,7 +219,7 @@ if (onnxruntime_USE_TENSORRT)
if ( CMAKE_COMPILER_IS_GNUCC )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers")
endif()
- set(CXX_VERSION_DEFINED TRUE)
+ set(CXX_VERSION_DEFINED TRUE)
add_subdirectory(${ONNXRUNTIME_ROOT}/../cmake/external/onnx-tensorrt)
set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS})
if (WIN32)
@@ -301,7 +305,7 @@ if (onnxruntime_USE_OPENVINO)
if(WIN32)
set(OPENVINO_LIB_DIR $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64/Release)
set(OPENVINO_TBB_DIR $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64/Release)
- set(OPENVINO_MKL_TINY_DIR $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/bin/intel64/Release)
+ set(OPENVINO_MKL_TINY_DIR $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/bin/intel64/Release)
else()
set(OPENVINO_LIB_DIR $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64/)
set(OPENVINO_TBB_DIR $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/tbb/lib)
@@ -325,9 +329,9 @@ if (onnxruntime_USE_OPENVINO)
else()
target_include_directories(onnxruntime_providers_openvino SYSTEM PUBLIC ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${OPENVINO_INCLUDE_DIR} ${OPENVINO_EXTENSIONS_DIR} ${OPENVINO_LIB_DIR} ${OPENVINO_TBB_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS})
endif()
-
- if (WIN32)
- string(REPLACE "include" "libs" PYTHON_LIB ${PYTHON_INCLUDE_DIRS})
+
+ if (WIN32)
+ string(REPLACE "include" "libs" PYTHON_LIB ${PYTHON_INCLUDE_DIRS})
find_package(InferenceEngine 2.1 REQUIRED)
set(PYTHON_LIBRARIES ${PYTHON_LIB})
set(OPENVINO_CPU_EXTENSION_DIR ${onnxruntime_BINARY_DIR}/ie_cpu_extension/${CMAKE_BUILD_TYPE})
@@ -428,21 +432,41 @@ if (onnxruntime_USE_DML)
onnxruntime_add_include_to_target(onnxruntime_providers_dml onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
add_dependencies(onnxruntime_providers_dml ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_providers_dml PRIVATE ${ONNXRUNTIME_ROOT} ${ONNXRUNTIME_ROOT}/../cmake/external/wil/include)
-
- target_link_libraries(onnxruntime_providers_dml ${CMAKE_CURRENT_BINARY_DIR}/packages/DirectML.0.0.1/build/DirectML.targets)
- target_link_libraries(onnxruntime_providers_dml d3d12.lib dxgi.lib)
+
+ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
+ if(NOT onnxruntime_target_platform STREQUAL "x86" AND NOT onnxruntime_target_platform STREQUAL "x64")
+ message(FATAL_ERROR "Target platform ${onnxruntime_target_platform} is not supported by DML")
+ endif()
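+ # Copy the DirectML redist binaries next to the provider's output so dependent binaries and tests can load them.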
+ foreach(file "DirectML.dll" "DirectML.pdb" "DirectML.Debug.dll" "DirectML.Debug.pdb")
+ add_custom_command(TARGET onnxruntime_providers_dml
+ POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different
+ "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/${file}" $)
+ endforeach()
+ endif()
+
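+ # Helper: links the restored DirectML.lib into a target and orders the target after the nuget RESTORE_PACKAGES step.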
+ function(target_add_dml target)
+ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
+ target_link_libraries(${target} PRIVATE "${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/DirectML.lib")
+ add_dependencies(${target} RESTORE_PACKAGES)
+ endif()
+ endfunction()
+
+ target_add_dml(onnxruntime_providers_dml)
+ target_link_libraries(onnxruntime_providers_dml PRIVATE d3d12.lib dxgi.lib delayimp.lib)
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll")
+
# The DML EP requires C++17
set_target_properties(onnxruntime_providers_dml PROPERTIES CXX_STANDARD 17)
set_target_properties(onnxruntime_providers_dml PROPERTIES CXX_STANDARD_REQUIRED ON)
-
+
target_compile_definitions(onnxruntime_providers_dml PRIVATE ONNX_NAMESPACE=onnx ONNX_ML LOTUS_LOG_THRESHOLD=2 LOTUS_ENABLE_STDERR_LOGGING PLATFORM_WINDOWS)
target_compile_definitions(onnxruntime_providers_dml PRIVATE UNICODE _UNICODE NOMINMAX)
if (MSVC)
target_compile_definitions(onnxruntime_providers_dml PRIVATE _SILENCE_CXX17_ITERATOR_BASE_CLASS_DEPRECATION_WARNING)
target_compile_options(onnxruntime_providers_dml PRIVATE "/W3")
endif()
-
+
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dml DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
set_target_properties(onnxruntime_providers_dml PROPERTIES LINKER_LANGUAGE CXX)
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index f6814a34d3b9d..1592f26be2c37 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -340,6 +340,9 @@ onnxruntime_add_include_to_target(onnxruntime_test_utils onnxruntime_framework G
if (onnxruntime_USE_DNNL)
target_compile_definitions(onnxruntime_test_utils PUBLIC USE_DNNL=1)
endif()
+if (onnxruntime_USE_DML)
+ target_add_dml(onnxruntime_test_utils)
+endif()
add_dependencies(onnxruntime_test_utils ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_test_utils PUBLIC "${TEST_SRC_DIR}/util/include" PRIVATE ${eigen_INCLUDE_DIRS} ${ONNXRUNTIME_ROOT})
set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest")
diff --git a/cmake/precompiled_header.cmake b/cmake/precompiled_header.cmake
new file mode 100644
index 0000000000000..dbdeb2bb508aa
--- /dev/null
+++ b/cmake/precompiled_header.cmake
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# Configures sources on a target to use a precompiled header. This function takes a target and
+# header name as input. The function will generate a .cpp file that includes the header and is used
+# to generate the precompiled header; this source file is added to the target's sources.
+function(target_precompiled_header target_name header_name)
+ if (MSVC AND CMAKE_VS_PLATFORM_TOOLSET)
+ # The input precompiled header source (i.e. the '.h' file used for the precompiled header).
+ set(pch_header_path ${header_name})
+ get_filename_component(header_base_name ${header_name} NAME_WE)
+
+ # Generate the source file that builds the precompiled header. The generated file will have
+ # the same base name as the input header name, but has the .cpp extension.
+ set(pch_source_path ${CMAKE_CURRENT_BINARY_DIR}/${header_base_name}.cpp)
+ set(pch_source_content "// THIS FILE IS GENERATED BY CMAKE\n#include \"${pch_header_path}\"")
+ file(WRITE ${pch_source_path} ${pch_source_content})
+ set_source_files_properties(${pch_source_path} PROPERTIES COMPILE_FLAGS "/Yc${pch_header_path}")
+
+ # The target's C++ sources use the precompiled header (/Yu). Source-level properties will
+ # take precedence over target-level properties, so this will not change the generated source
+ # file's property to create the precompiled header (/Yc).
+ target_compile_options(${target_name} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/Yu${header_name}>)
+
+ # Append generated precompiled source to target's sources.
+ target_sources(${target_name} PRIVATE ${pch_source_path})
+
+ endif()
+endfunction()
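+
+# Example usage (as in winml.cmake): target_precompiled_header(winml_lib_ort pch.h)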
diff --git a/cmake/wil.cmake b/cmake/wil.cmake
new file mode 100644
index 0000000000000..36a8bc9d3cd18
--- /dev/null
+++ b/cmake/wil.cmake
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+add_library(wil INTERFACE)
+target_include_directories(wil INTERFACE external/wil/include/)
\ No newline at end of file
diff --git a/cmake/winml.cmake b/cmake/winml.cmake
new file mode 100644
index 0000000000000..d3814c59492a1
--- /dev/null
+++ b/cmake/winml.cmake
@@ -0,0 +1,629 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+include(precompiled_header.cmake)
+include(winml_sdk_helpers.cmake)
+include(winml_cppwinrt.cmake)
+
+# get the current nuget sdk kit directory
+get_sdk(sdk_folder sdk_version)
+set(target_folder ONNXRuntime/winml)
+set(winml_adapter_dir ${REPO_ROOT}/winml/adapter)
+set(winml_api_root ${REPO_ROOT}/winml/api)
+set(winml_dll_dir ${REPO_ROOT}/winml/dll)
+set(winml_lib_dir ${REPO_ROOT}/winml/lib)
+set(winml_lib_api_dir ${REPO_ROOT}/winml/lib/api)
+set(winml_lib_api_image_dir ${REPO_ROOT}/winml/lib/api.image)
+set(winml_lib_api_ort_dir ${REPO_ROOT}/winml/lib/api.ort)
+set(winml_lib_common_dir ${REPO_ROOT}/winml/lib/common)
+set(winml_lib_telemetry_dir ${REPO_ROOT}/winml/lib/telemetry)
+
+# Version parts for Windows.AI.MachineLearning.dll.
+set(WINML_VERSION_MAJOR_PART 0 CACHE STRING "First part of numeric file/product version.")
+set(WINML_VERSION_MINOR_PART 0 CACHE STRING "Second part of numeric file/product version.")
+set(WINML_VERSION_BUILD_PART 0 CACHE STRING "Third part of numeric file/product version.")
+set(WINML_VERSION_PRIVATE_PART 0 CACHE STRING "Fourth part of numeric file/product version.")
+set(WINML_VERSION_STRING "Internal Build" CACHE STRING "String representation of file/product version.")
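+# Each version part is a CACHE STRING, so release pipelines can override them at configure
+# time, e.g. -DWINML_VERSION_MAJOR_PART=1 (hypothetical value).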
+
+get_filename_component(exclusions "${winml_api_root}/exclusions.txt" ABSOLUTE)
+convert_forward_slashes_to_back(${exclusions} CPPWINRT_COMPONENT_EXCLUSION_LIST)
+
+# For winrt idl files:
+# 1) the file name must match the casing of the file on disk.
+# 2) for winrt idls the casing must match the namespaces within exactly (Windows.AI.MachineLearning).
+# target_cppwinrt will attempt to create a winmd with the same name and casing as the supplied
+# idl file. If the name of the winmd file does not match the contained namespaces, cppwinrt.exe
+# will generate component template files with fully qualified names, which will not match the existing
+# generated component files.
+#
+# For native idl files there are no casing restrictions.
+get_filename_component(winrt_idl "${winml_api_root}/Windows.AI.MachineLearning.idl" ABSOLUTE)
+get_filename_component(idl_native "${winml_api_root}/windows.ai.machineLearning.native.idl" ABSOLUTE)
+get_filename_component(idl_native_internal "${winml_api_root}/windows.ai.machineLearning.native.internal.idl" ABSOLUTE)
+
+# generate cppwinrt sdk
+add_generate_cppwinrt_sdk_headers_target(
+ winml_sdk_cppwinrt # the target name
+ ${sdk_folder} # location of sdk folder
+ ${sdk_version} # sdk version
+ ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include # output folder, relative to CMAKE_BINARY_DIR, where the generated sdk headers will be placed
+ ${target_folder} # folder where this target will be placed
+)
+
+# generate winml headers from idl
+target_cppwinrt(winml_api
+ ${winrt_idl} # winml winrt idl to compile
+ ${winml_lib_api_dir} # location for cppwinrt generated component sources
+ ${sdk_folder} # location of sdk folder
+ ${sdk_version} # sdk version
+ ${target_folder} # the folder this target will be placed under
+)
+
+target_midl(winml_api_native
+ ${idl_native} # winml native idl to compile
+ ${sdk_folder} # location of sdk folder
+ ${sdk_version} # sdk version
+ ${target_folder} # the folder this target will be placed under
+)
+
+target_midl(winml_api_native_internal
+ ${idl_native_internal} # winml internal native idl to compile
+ ${sdk_folder} # location of sdk folder
+ ${sdk_version} # sdk version
+ ${target_folder}) # the folder this target will be placed under
+
+###########################
+# Add winml_lib_telemetry
+###########################
+
+# Add static library that will be archived/linked for both static/dynamic library
+add_library(winml_lib_telemetry STATIC
+ ${winml_lib_telemetry_dir}/inc/TelemetryEvent.h
+ ${ONNXRUNTIME_INCLUDE_DIR}/core/platform/windows/TraceLoggingConfig.h
+ ${winml_lib_common_dir}/inc/WinMLTelemetryHelper.h
+ ${winml_lib_telemetry_dir}/Telemetry.cpp
+ ${winml_lib_telemetry_dir}/TelemetryEvent.cpp
+ ${winml_lib_telemetry_dir}/WinMLTelemetryHelper.cpp
+ ${winml_lib_telemetry_dir}/pch.h
+)
+
+# Compiler options
+target_compile_features(winml_lib_telemetry PRIVATE cxx_std_17)
+target_compile_options(winml_lib_telemetry PRIVATE /GR- /await /wd4238)
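+# When telemetry is enabled, /FI force-includes the private TraceLogging config into every translation unit.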
+if (onnxruntime_USE_TELEMETRY)
+ set_target_properties(winml_lib_telemetry PROPERTIES COMPILE_FLAGS "/FI${ONNXRUNTIME_INCLUDE_DIR}/core/platform/windows/TraceLoggingConfigPrivate.h")
+endif()
+
+# Compiler flags
+target_compile_definitions(winml_lib_telemetry PRIVATE PLATFORM_WINDOWS)
+target_compile_definitions(winml_lib_telemetry PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators
+
+# Specify the usage of a precompiled header
+target_precompiled_header(winml_lib_telemetry pch.h)
+
+# Includes
+target_include_directories(winml_lib_telemetry PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include)
+target_include_directories(winml_lib_telemetry PRIVATE ${CMAKE_SOURCE_DIR}/common/inc)
+target_include_directories(winml_lib_telemetry PRIVATE ${winml_lib_telemetry_dir})
+target_include_directories(winml_lib_telemetry PRIVATE ${winml_lib_common_dir}/inc)
+target_include_directories(winml_lib_telemetry PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}/core/platform/windows)
+
+# Properties
+set_target_properties(winml_lib_telemetry
+ PROPERTIES
+ FOLDER
+ ${target_folder})
+
+# Link libraries
+target_link_libraries(winml_lib_telemetry PRIVATE wil)
+
+###########################
+# Add winml_lib_ort
+###########################
+
+list(APPEND winml_lib_api_ort_files
+ ${winml_lib_api_ort_dir}/inc/OnnxruntimeProvider.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeCpuSessionBuilder.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeCpuSessionBuilder.cpp
+ ${winml_lib_api_ort_dir}/OnnxruntimeDescriptorConverter.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeDescriptorConverter.cpp
+ ${winml_lib_api_ort_dir}/OnnxruntimeEngine.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeEngine.cpp
+ ${winml_lib_api_ort_dir}/OnnxruntimeEngineBuilder.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeEngineBuilder.cpp
+ ${winml_lib_api_ort_dir}/OnnxruntimeEnvironment.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeEnvironment.cpp
+ ${winml_lib_api_ort_dir}/OnnxruntimeModel.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeModel.cpp
+ ${winml_lib_api_ort_dir}/OnnxruntimeSessionBuilder.h
+ ${winml_lib_api_ort_dir}/pch.h
+ )
+
+if (onnxruntime_USE_DML)
+ list(APPEND winml_lib_api_ort_files
+ ${winml_lib_api_ort_dir}/OnnxruntimeDmlSessionBuilder.h
+ ${winml_lib_api_ort_dir}/OnnxruntimeDmlSessionBuilder.cpp
+ )
+endif(onnxruntime_USE_DML)
+
+# Add static library that will be archived/linked for both static/dynamic library
+add_library(winml_lib_ort STATIC ${winml_lib_api_ort_files})
+
+# Compiler options
+target_compile_features(winml_lib_ort PRIVATE cxx_std_17)
+target_compile_options(winml_lib_ort PRIVATE /GR- /await /wd4238)
+
+# Compiler definitions
+target_compile_definitions(winml_lib_ort PRIVATE PLATFORM_WINDOWS)
+target_compile_definitions(winml_lib_ort PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators
+
+# Specify the usage of a precompiled header
+target_precompiled_header(winml_lib_ort pch.h)
+
+# Includes
+target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers
+target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers
+target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers
+
+target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
+
+target_include_directories(winml_lib_ort PRIVATE ${REPO_ROOT}/winml)
+target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_dir}) # needed for generated headers
+target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_core_dir})
+target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_ort_dir})
+target_include_directories(winml_lib_ort PRIVATE ${winml_lib_common_dir}/inc)
+target_include_directories(winml_lib_ort PRIVATE ${ONNXRUNTIME_INCLUDE_DIR})
+target_include_directories(winml_lib_ort PRIVATE ${ONNXRUNTIME_ROOT})
+
+set_target_properties(winml_lib_ort
+ PROPERTIES
+ FOLDER
+ ${target_folder})
+
+# Add deps
+add_dependencies(winml_lib_ort winml_sdk_cppwinrt)
+add_dependencies(winml_lib_ort winml_api)
+add_dependencies(winml_lib_ort winml_api_native)
+add_dependencies(winml_lib_ort winml_api_native_internal)
+
+# Link libraries
+target_link_libraries(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/packages/DirectML.0.0.1/build/DirectML.targets)
+target_link_libraries(winml_lib_ort PRIVATE wil)
+
+
+###########################
+# Add winml_adapter
+###########################
+
+list(APPEND winml_adapter_files
+ ${winml_adapter_dir}/pch.h
+ ${winml_adapter_dir}/winml_adapter_apis.h
+ ${winml_adapter_dir}/winml_adapter_c_api.h
+ ${winml_adapter_dir}/winml_adapter_c_api.cpp
+ ${winml_adapter_dir}/winml_adapter_dml.cpp
+ ${winml_adapter_dir}/winml_adapter_environment.cpp
+ ${winml_adapter_dir}/winml_adapter_execution_provider.cpp
+ ${winml_adapter_dir}/winml_adapter_model.cpp
+ ${winml_adapter_dir}/winml_adapter_model.h
+ ${winml_adapter_dir}/winml_adapter_session.cpp
+ )
+
+if (onnxruntime_USE_DML)
+ list(APPEND winml_adapter_files
+ ${winml_adapter_dir}/abi_custom_registry_impl.cpp
+ ${winml_adapter_dir}/abi_custom_registry_impl.h
+ )
+endif(onnxruntime_USE_DML)
+
+add_library(winml_adapter ${winml_adapter_files})
+
+# wil requires C++17
+set_target_properties(winml_adapter PROPERTIES CXX_STANDARD 17)
+set_target_properties(winml_adapter PROPERTIES CXX_STANDARD_REQUIRED ON)
+
+# Compiler definitions
+onnxruntime_add_include_to_target(winml_adapter onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
+target_include_directories(winml_adapter PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS})
+add_dependencies(winml_adapter ${onnxruntime_EXTERNAL_DEPENDENCIES})
+
+# Specify the usage of a precompiled header
+target_precompiled_header(winml_adapter pch.h)
+
+# Includes
+target_include_directories(winml_adapter PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) # windows machine learning generated component headers
+target_include_directories(winml_adapter PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers
+target_include_directories(winml_adapter PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers
+target_include_directories(winml_adapter PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers
+target_include_directories(winml_adapter PRIVATE ${winml_lib_api_dir}) # needed for generated headers
+target_include_directories(winml_adapter PRIVATE ${winml_lib_dir})
+target_include_directories(winml_adapter PRIVATE ${winml_adapter_dir})
+target_include_directories(winml_adapter PRIVATE ${winml_lib_common_dir}/inc)
+
+set_target_properties(winml_adapter
+ PROPERTIES
+ FOLDER
+ ${target_folder})
+
+# Add deps
+add_dependencies(winml_adapter winml_sdk_cppwinrt)
+add_dependencies(winml_adapter winml_api)
+add_dependencies(winml_adapter winml_api_native)
+add_dependencies(winml_adapter winml_api_native_internal)
+
+# Link libraries
+target_link_libraries(winml_adapter PRIVATE wil)
+if (onnxruntime_USE_DML)
+ target_add_dml(winml_adapter)
+endif(onnxruntime_USE_DML)
+
+# add it to the onnxruntime shared library
+set(onnxruntime_winml winml_adapter)
+list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES winml_adapter)
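+# onnxruntime.cmake consumes ${onnxruntime_winml} to link winml_adapter into the onnxruntime shared library.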
+
+###########################
+# Add winml_lib_image
+###########################
+
+# Add static library that will be archived/linked for both static/dynamic library
+add_library(winml_lib_image STATIC
+ ${winml_lib_api_image_dir}/inc/ConverterResourceStore.h
+ ${winml_lib_api_image_dir}/inc/D3DDeviceCache.h
+ ${winml_lib_api_image_dir}/inc/DeviceHelpers.h
+ ${winml_lib_api_image_dir}/inc/ImageConversionHelpers.h
+ ${winml_lib_api_image_dir}/inc/ImageConversionTypes.h
+ ${winml_lib_api_image_dir}/inc/ImageConverter.h
+ ${winml_lib_api_image_dir}/inc/TensorToVideoFrameConverter.h
+ ${winml_lib_api_image_dir}/inc/VideoFrameToTensorConverter.h
+ ${winml_lib_api_image_dir}/CpuDetensorizer.h
+ ${winml_lib_api_image_dir}/CpuTensorizer.h
+ ${winml_lib_api_image_dir}/pch.h
+ ${winml_lib_api_image_dir}/ConverterResourceStore.cpp
+ ${winml_lib_api_image_dir}/D3DDeviceCache.cpp
+ ${winml_lib_api_image_dir}/DeviceHelpers.cpp
+ ${winml_lib_api_image_dir}/ImageConversionHelpers.cpp
+ ${winml_lib_api_image_dir}/ImageConverter.cpp
+ ${winml_lib_api_image_dir}/TensorToVideoFrameConverter.cpp
+ ${winml_lib_api_image_dir}/VideoFrameToTensorConverter.cpp
+)
+
+# Compiler options
+target_compile_features(winml_lib_image PRIVATE cxx_std_17)
+target_compile_options(winml_lib_image PRIVATE /GR- /await /wd4238)
+
+# Compiler flags
+target_compile_definitions(winml_lib_image PRIVATE ONNX_NAMESPACE=onnx)
+target_compile_definitions(winml_lib_image PRIVATE ONNX_ML)
+target_compile_definitions(winml_lib_image PRIVATE LOTUS_LOG_THRESHOLD=2)
+target_compile_definitions(winml_lib_image PRIVATE LOTUS_ENABLE_STDERR_LOGGING)
+target_compile_definitions(winml_lib_image PRIVATE PLATFORM_WINDOWS)
+target_compile_definitions(winml_lib_image PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators
+
+# Specify the usage of a precompiled header
+target_precompiled_header(winml_lib_image pch.h)
+
+# Includes
+target_include_directories(winml_lib_image PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) # windows machine learning generated component headers
+target_include_directories(winml_lib_image PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers
+target_include_directories(winml_lib_image PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers
+target_include_directories(winml_lib_image PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers
+target_include_directories(winml_lib_image PRIVATE ${ONNXRUNTIME_ROOT}/core/providers/dml/DmlExecutionProvider/src/External/D3DX12) # for d3dx12.h
+target_include_directories(winml_lib_image PRIVATE ${winml_lib_api_dir}) # needed for generated headers
+target_include_directories(winml_lib_image PRIVATE ${winml_lib_api_image_dir})
+target_include_directories(winml_lib_image PRIVATE ${winml_lib_common_dir}/inc)
+target_include_directories(winml_lib_image PRIVATE ${ONNXRUNTIME_ROOT})
+target_include_directories(winml_lib_image PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}) # for status.h
+target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/gsl/include)
+target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/onnx)
+target_include_directories(winml_lib_image PRIVATE ${REPO_ROOT}/cmake/external/protobuf/src)
+
+# Properties
+set_target_properties(winml_lib_image
+ PROPERTIES
+ FOLDER
+ ${target_folder})
+
+# Add deps
+add_dependencies(winml_lib_image winml_sdk_cppwinrt)
+add_dependencies(winml_lib_image winml_api)
+add_dependencies(winml_lib_image winml_api_native)
+add_dependencies(winml_lib_image winml_api_native_internal)
+
+# Link libraries
+target_link_libraries(winml_lib_image PRIVATE wil winml_lib_common)
+if (onnxruntime_USE_DML)
+ target_add_dml(winml_lib_image)
+endif(onnxruntime_USE_DML)
+
+
+###########################
+# Add winml_lib_api
+###########################
+
+# Add static library that will be archived/linked for both static/dynamic library
+add_library(winml_lib_api STATIC
+ ${winml_lib_api_dir}/impl/FeatureCompatibility.h
+ ${winml_lib_api_dir}/impl/IMapFeatureValue.h
+ ${winml_lib_api_dir}/impl/ISequenceFeatureValue.h
+ ${winml_lib_api_dir}/impl/MapBase.h
+ ${winml_lib_api_dir}/impl/SequenceBase.h
+ ${winml_lib_api_dir}/impl/Tensor.h
+ ${winml_lib_api_dir}/impl/TensorBase.h
+ ${winml_lib_api_dir}/impl/TensorBuffer.h
+ ${winml_lib_api_dir}/impl/TensorKindFrom.h
+ ${winml_lib_api_dir}/impl/TensorMemoryBufferReference.h
+ ${winml_lib_api_dir}/ImageFeatureDescriptor.cpp
+ ${winml_lib_api_dir}/ImageFeatureDescriptor.h
+ ${winml_lib_api_dir}/ImageFeatureValue.cpp
+ ${winml_lib_api_dir}/ImageFeatureValue.h
+ ${winml_lib_api_dir}/LearningModel.cpp
+ ${winml_lib_api_dir}/LearningModel.h
+ ${winml_lib_api_dir}/LearningModelBinding.cpp
+ ${winml_lib_api_dir}/LearningModelBinding.h
+ ${winml_lib_api_dir}/LearningModelDevice.cpp
+ ${winml_lib_api_dir}/LearningModelDevice.h
+ ${winml_lib_api_dir}/LearningModelEvaluationResult.cpp
+ ${winml_lib_api_dir}/LearningModelEvaluationResult.h
+ ${winml_lib_api_dir}/LearningModelSession.cpp
+ ${winml_lib_api_dir}/LearningModelSession.h
+ ${winml_lib_api_dir}/LearningModelSessionOptions.cpp
+ ${winml_lib_api_dir}/LearningModelSessionOptions.h
+ ${winml_lib_api_dir}/MapFeatureDescriptor.cpp
+ ${winml_lib_api_dir}/MapFeatureDescriptor.h
+ ${winml_lib_api_dir}/SequenceFeatureDescriptor.cpp
+ ${winml_lib_api_dir}/SequenceFeatureDescriptor.h
+ ${winml_lib_api_dir}/TensorFeatureDescriptor.cpp
+ ${winml_lib_api_dir}/TensorFeatureDescriptor.h
+ ${winml_lib_api_dir}/pch/pch.h
+)
+
+# Compiler options
+target_compile_features(winml_lib_api PRIVATE cxx_std_17)
+target_compile_options(winml_lib_api PRIVATE /GR- /await /bigobj /wd4238)
+
+# Compiler flags
+target_compile_definitions(winml_lib_api PRIVATE ONNX_NAMESPACE=onnx)
+target_compile_definitions(winml_lib_api PRIVATE ONNX_ML)
+target_compile_definitions(winml_lib_api PRIVATE LOTUS_LOG_THRESHOLD=2)
+target_compile_definitions(winml_lib_api PRIVATE LOTUS_ENABLE_STDERR_LOGGING)
+target_compile_definitions(winml_lib_api PRIVATE PLATFORM_WINDOWS)
+target_compile_definitions(winml_lib_api PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators
+
+# Specify the usage of a precompiled header
+target_precompiled_header(winml_lib_api pch.h)
+
+# Includes
+target_include_directories(winml_lib_api PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers
+target_include_directories(winml_lib_api PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers
+target_include_directories(winml_lib_api PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers
+
+target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_dir})
+target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_dir}/pch)
+target_include_directories(winml_lib_api PRIVATE ${winml_adapter_dir})
+target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_image_dir}/inc)
+target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_ort_dir}/inc)
+target_include_directories(winml_lib_api PRIVATE ${winml_lib_telemetry_dir}/inc)
+target_include_directories(winml_lib_api PRIVATE ${winml_lib_common_dir}/inc)
+
+target_include_directories(winml_lib_api PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
+target_include_directories(winml_lib_api PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/external/date/include)
+target_include_directories(winml_lib_api PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/external/gsl/include)
+target_include_directories(winml_lib_api PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/external/onnx)
+
+target_include_directories(winml_lib_api PRIVATE ${ONNXRUNTIME_INCLUDE_DIR})
+target_include_directories(winml_lib_api PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}/core/graph)
+target_include_directories(winml_lib_api PRIVATE ${ONNXRUNTIME_ROOT})
+target_include_directories(winml_lib_api PRIVATE ${ONNXRUNTIME_ROOT}/core/graph)
+target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/eigen)
+target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/onnx)
+target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/protobuf/src)
+target_include_directories(winml_lib_api PRIVATE ${REPO_ROOT}/cmake/external/gsl/include)
+
+# Properties
+set_target_properties(winml_lib_api
+ PROPERTIES
+ FOLDER
+ ${target_folder})
+
+# Add deps
+add_dependencies(winml_lib_api onnx)
+add_dependencies(winml_lib_api winml_sdk_cppwinrt)
+add_dependencies(winml_lib_api winml_api)
+add_dependencies(winml_lib_api winml_api_native)
+add_dependencies(winml_lib_api winml_api_native_internal)
+
+# Link libraries
+target_link_libraries(winml_lib_api PRIVATE wil winml_lib_telemetry)
+if (onnxruntime_USE_DML)
+ target_add_dml(winml_lib_api)
+endif(onnxruntime_USE_DML)
+
+###########################
+# Add winml_lib_common
+###########################
+
+add_library(winml_lib_common STATIC
+ ${winml_lib_common_dir}/inc/common.h
+ ${winml_lib_common_dir}/inc/CommonDeviceHelpers.h
+ ${winml_lib_common_dir}/inc/cppwinrt_onnx.h
+ ${winml_lib_common_dir}/inc/dx.h
+ ${winml_lib_common_dir}/inc/errors.h
+ ${winml_lib_common_dir}/inc/iengine.h
+ ${winml_lib_common_dir}/inc/NamespaceAliases.h
+ ${winml_lib_common_dir}/inc/onnx.h
+ ${winml_lib_common_dir}/inc/PheonixSingleton.h
+ ${winml_lib_common_dir}/inc/StringHelpers.h
+ ${winml_lib_common_dir}/inc/WinMLTelemetryHelper.h
+ ${winml_lib_common_dir}/inc/WinML_Lock.h
+ ${winml_lib_common_dir}/inc/winrt_headers.h
+ ${winml_lib_common_dir}/CommonDeviceHelpers.cpp
+)
+
+set_target_properties(winml_lib_common PROPERTIES CXX_STANDARD 17)
+set_target_properties(winml_lib_common PROPERTIES CXX_STANDARD_REQUIRED ON)
+target_compile_options(winml_lib_common PRIVATE /GR- /await /bigobj /wd4238)
+target_link_libraries(winml_lib_common PRIVATE wil)
+target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api)
+target_compile_definitions(winml_lib_common PRIVATE
+ ONNX_NAMESPACE=onnx
+ ONNX_ML
+ LOTUS_LOG_THRESHOLD=2
+ LOTUS_ENABLE_STDERR_LOGGING
+ PLATFORM_WINDOWS
+ _SCL_SECURE_NO_WARNINGS)
+add_dependencies(winml_lib_common winml_sdk_cppwinrt)
+add_dependencies(winml_lib_common winml_api)
+add_dependencies(winml_lib_common winml_api_native)
+add_dependencies(winml_lib_common winml_api_native_internal)
+
+target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers
+target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers
+target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers
+target_include_directories(winml_lib_common PRIVATE ${winml_lib_api_dir})
+target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
+target_include_directories(winml_lib_common PRIVATE ${winml_lib_common_dir}/inc)
+target_precompiled_header(winml_lib_common inc/pch.h)
+
+if (onnxruntime_USE_DML)
+ target_add_dml(winml_lib_common)
+endif()
+
+###########################
+# Add winml_dll
+###########################
+
+set_source_files_properties(
+ ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated/module.g.excl.cpp
+ PROPERTIES
+ GENERATED
+ TRUE)
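+# module.g.excl.cpp is emitted by the winml_api target (see winml_cppwinrt.cmake), hence marked GENERATED.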
+
+# Add library
+add_library(winml_dll SHARED
+ ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated/module.g.excl.cpp
+ ${winml_dll_dir}/windows.ai.machinelearning.def
+ ${winml_dll_dir}/winml.rc
+ ${winml_dll_dir}/pch.h
+ ${winml_dll_dir}/module.cpp
+)
+
+# Compiler options
+target_compile_features(winml_dll PRIVATE cxx_std_17)
+target_compile_options(winml_dll PRIVATE /GR- /await /bigobj /wd4238)
+
+# Compiler definitions
+target_compile_definitions(winml_dll PRIVATE ONNX_NAMESPACE=onnx)
+target_compile_definitions(winml_dll PRIVATE ONNX_ML)
+target_compile_definitions(winml_dll PRIVATE LOTUS_LOG_THRESHOLD=2)
+target_compile_definitions(winml_dll PRIVATE LOTUS_ENABLE_STDERR_LOGGING)
+target_compile_definitions(winml_dll PRIVATE PLATFORM_WINDOWS)
+target_compile_definitions(winml_dll PRIVATE VER_MAJOR=${WINML_VERSION_MAJOR_PART})
+target_compile_definitions(winml_dll PRIVATE VER_MINOR=${WINML_VERSION_MINOR_PART})
+target_compile_definitions(winml_dll PRIVATE VER_BUILD=${WINML_VERSION_BUILD_PART})
+target_compile_definitions(winml_dll PRIVATE VER_PRIVATE=${WINML_VERSION_PRIVATE_PART})
+target_compile_definitions(winml_dll PRIVATE VER_STRING=\"${WINML_VERSION_STRING}\")
+
+# Specify the usage of a precompiled header
+target_precompiled_header(winml_dll pch.h)
+
+# Includes
+target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers
+target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers
+target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers
+
+target_include_directories(winml_dll PRIVATE ${winml_dll_dir})
+target_include_directories(winml_dll PRIVATE ${winml_lib_api_dir})
+target_include_directories(winml_dll PRIVATE ${winml_lib_api_dir}/impl)
+target_include_directories(winml_dll PRIVATE ${winml_lib_api_ort_dir}/inc)
+target_include_directories(winml_dll PRIVATE ${winml_adapter_dir})
+target_include_directories(winml_dll PRIVATE ${winml_lib_api_image_dir}/inc)
+target_include_directories(winml_dll PRIVATE ${winml_lib_telemetry_dir}/inc)
+target_include_directories(winml_dll PRIVATE ${winml_lib_common_dir}/inc)
+
+target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
+target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/external/date/include)
+target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/external/gsl/include)
+target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/external/onnx)
+
+target_include_directories(winml_dll PRIVATE ${ONNXRUNTIME_INCLUDE_DIR})
+target_include_directories(winml_dll PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}/core/graph)
+target_include_directories(winml_dll PRIVATE ${ONNXRUNTIME_ROOT})
+target_include_directories(winml_dll PRIVATE ${ONNXRUNTIME_ROOT}/core/graph)
+target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/onnx)
+target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/protobuf/src)
+target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/gsl/include)
+target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/eigen)
+
+# Properties
+set_target_properties(winml_dll
+ PROPERTIES
+ OUTPUT_NAME windows.ai.machinelearning)
+
+if (onnxruntime_USE_DML)
+ set(delayload_dml "/DELAYLOAD:directml.dll")
+endif(onnxruntime_USE_DML)
+
+# The default libraries to link with in Windows are kernel32.lib;user32.lib;gdi32.lib;winspool.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;comdlg32.lib;advapi32.lib
+# Remove them and use the onecore umbrella library instead
+foreach(default_lib kernel32.lib user32.lib gdi32.lib winspool.lib shell32.lib ole32.lib oleaut32.lib uuid.lib comdlg32.lib advapi32.lib)
+ set(removed_libs "${removed_libs} /NODEFAULTLIB:${default_lib}")
+endforeach()
+set(CMAKE_C_STANDARD_LIBRARIES "${removed_libs} onecoreuap.lib")
+set(CMAKE_CXX_STANDARD_LIBRARIES "${removed_libs} onecoreuap.lib")
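+# The .def file controls winml_dll's exports; d3d12/d3d11/dxgi (and DirectML when enabled) are
+# delay-loaded so they are only pulled in when a GPU device is actually used.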
+set_target_properties(winml_dll
+ PROPERTIES
+ LINK_FLAGS
+ "/DEF:${WINML_DIR}/windows.ai.machinelearning.def ${os_component_link_flags} /DELAYLOAD:d3d12.dll /DELAYLOAD:d3d11.dll /DELAYLOAD:dxgi.dll ${delayload_dml}")
+
+
+set_target_properties(winml_dll
+ PROPERTIES
+ FOLDER
+ ${target_folder})
+
+# Add deps
+add_dependencies(winml_dll winml_sdk_cppwinrt)
+add_dependencies(winml_dll winml_api_native)
+add_dependencies(winml_dll winml_api_native_internal)
+
+# Any project that links in debug_alloc.obj needs this lib.
+# unresolved external symbol __imp_SymSetOptions
+# ... __imp_SymGetLineFromAddr64
+# ... __imp_SymInitialize
+# ... __imp_SymFromAddr
+if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
+ set(DBGHELP dbghelp.lib)
+endif("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
+
+# Link libraries
+target_link_libraries(winml_dll PRIVATE onnxruntime)
+target_link_libraries(winml_dll PRIVATE re2)
+target_link_libraries(winml_dll PRIVATE wil)
+target_link_libraries(winml_dll PRIVATE winml_lib_api)
+target_link_libraries(winml_dll PRIVATE winml_lib_image)
+target_link_libraries(winml_dll PRIVATE winml_lib_ort)
+target_link_libraries(winml_dll PRIVATE winml_lib_telemetry)
+target_link_libraries(winml_dll PRIVATE delayimp.lib)
+target_link_libraries(winml_dll PRIVATE ${DBGHELP})
+
+# winml_dll is 1 of 3 projects that fail in link with 'failed to do memory mapped file I/O' (Release only)
+# when using the x86 hosted architecture. When using the LKG compiler this becomes a problem
+# because it falls back to incorrectly using the public version of link.
+# To avoid the scenario completely, this tells cl/link to start with x64 hosted,
+# rather than waiting for it to fail, retry, and resolve incorrectly.
+if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
+ set_target_properties(winml_dll PROPERTIES VS_GLOBAL_PreferredToolArchitecture "x64")
+endif("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
+
+option(onnxruntime_BUILD_WINML_TESTS "Build WinML tests" ON)
+if (onnxruntime_BUILD_WINML_TESTS)
+ include(winml_unittests.cmake)
+endif()
+
+# Suppress linker warning 4199, which complains that no imports are found for the delay-loaded library cublas64*.lib.
+# When cuda is enabled in the pipeline, it sets CMAKE_SHARED_LINKER_FLAGS, which affects all targets including winml_dll.
+# However, there are no cuda imports in winml_dll, so the linker throws the 4199 warning.
+# Ignoring it allows winml_dll to build with cuda enabled.
+set_target_properties(winml_dll
+ PROPERTIES
+ LINK_FLAGS
+ "/ignore:4199")
\ No newline at end of file
diff --git a/cmake/winml_cppwinrt.cmake b/cmake/winml_cppwinrt.cmake
new file mode 100644
index 0000000000000..c047689b32588
--- /dev/null
+++ b/cmake/winml_cppwinrt.cmake
@@ -0,0 +1,223 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# This script adds cppwinrt support for VS-generated projects.
+#
+# target_cppwinrt(foo bar.idl out_sources_folder sdk_folder sdk_version folder_name)
+#
+# Calling the target_midl function runs midlrt.exe and produces bar.h
+# Calling the target_cppwinrt function does two things:
+#
+# 1) Adds a target "bar.cppwinrt", which performs the midl and cppwinrt
+# builds and produces:
+# bar.h
+# bar.winmd
+# bar.tlb
+# module.g.cpp
+#
+# 2) Adds a dependency to the new custom target "bar.cppwinrt"
+
+function(target_midl
+ target_name
+ idl_file
+ sdk_folder # sdk kit directory
+ sdk_version # sdk version
+ folder_name)
+ if (MSVC)
+ # get sdk include paths for midl
+ get_sdk_include_folder(${sdk_folder} ${sdk_version} sdk_include_folder)
+ set(um_sdk_directory "${sdk_include_folder}/um")
+ set(shared_sdk_directory "${sdk_include_folder}/shared")
+ set(winrt_sdk_directory "${sdk_include_folder}/winrt")
+
+ # get sdk metadata path
+ get_sdk_metadata_folder(${sdk_folder} ${sdk_version} sdk_metadata_directory_forward_slashes)
+ convert_forward_slashes_to_back(${sdk_metadata_directory_forward_slashes} sdk_metadata_directory)
+
+ # get midl
+ get_sdk_midl_exe(${sdk_folder} ${sdk_version} midl_exe)
+
+ # Filename variables
+ get_filename_component(file_name_with_extension ${idl_file} NAME)
+ string(REGEX REPLACE "\\.[^.]*$" "" file_name ${file_name_with_extension})
+ set(header_filename ${file_name}.h)
+ convert_forward_slashes_to_back(${idl_file} idl_file_forward_slash)
+
+ # using add_custom_command trick to prevent rerunning script unless ${file} is changed
+ add_custom_command(
+ OUTPUT ${header_filename}
+ COMMAND ${midl_exe}
+ /metadata_dir ${sdk_metadata_directory}
+ /W1 /char signed /nologo /winrt
+ /no_settings_comment /no_def_idir /target "NT60"
+ /I ${um_sdk_directory}
+ /I ${shared_sdk_directory}
+ /I ${winrt_sdk_directory}
+ /I ${CMAKE_CURRENT_SOURCE_DIR}
+ /h ${header_filename}
+ ${idl_file_forward_slash}
+ DEPENDS ${idl_file}
+ )
+
+ add_custom_target(
+ ${target_name}
+ ALL
+ DEPENDS ${header_filename}
+ )
+
+ set_target_properties(${target_name} PROPERTIES FOLDER ${folder_name})
+ endif()
+endfunction()
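+
+# Example usage (as in winml.cmake):
+#   target_midl(winml_api_native ${idl_native} ${sdk_folder} ${sdk_version} ${target_folder})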
+
+function(target_cppwinrt
+ target_name # the name of the target to add
+ file # name of the idl file to compile
+ out_sources_folder # path where generated sources will be placed
+ sdk_folder # sdk kit directory
+ sdk_version # sdk version
+ folder_name # folder this target will be placed
+)
+ if (MSVC)
+ # get sdk include paths for midl
+ get_sdk_include_folder(${sdk_folder} ${sdk_version} sdk_include_folder)
+ set(um_sdk_directory "${sdk_include_folder}/um")
+ set(shared_sdk_directory "${sdk_include_folder}/shared")
+ set(winrt_sdk_directory "${sdk_include_folder}/winrt")
+
+ # get sdk metadata path
+ get_sdk_metadata_folder(${sdk_folder} ${sdk_version} sdk_metadata_directory_forward_slashes)
+ convert_forward_slashes_to_back(${sdk_metadata_directory_forward_slashes} sdk_metadata_directory)
+
+ # get midl
+ get_sdk_midl_exe(${sdk_folder} ${sdk_version} midl_exe)
+
+ # get cppwinrt
+ get_sdk_cppwinrt_exe(${sdk_folder} ${sdk_version} cppwinrt_exe)
+
+ # Filename variables
+ convert_forward_slashes_to_back(${file} idl_file_forward_slash)
+ get_filename_component(file_name_with_extension ${file} NAME)
+ string(REGEX REPLACE "\\.[^.]*$" "" fileName ${file_name_with_extension})
+ set(header_filename ${fileName}.h)
+ set(winmd_filename ${fileName}.winmd)
+ set(tlb_filename ${fileName}.tlb)
+
+ # Get directory
+ get_filename_component(idl_source_directory ${file} DIRECTORY)
+
+ set(target_outputs ${CMAKE_CURRENT_BINARY_DIR}/${target_name})
+ convert_forward_slashes_to_back(${target_outputs}/comp output_dir_back_slash)
+ convert_forward_slashes_to_back(${target_outputs}/temp temp_dir_back_slash)
+ convert_forward_slashes_to_back(${target_outputs}/comp_generated generated_dir_back_slash)
+ convert_forward_slashes_to_back(${generated_dir_back_slash}/module.g.cpp module_g_cpp_back_slash)
+ convert_forward_slashes_to_back(${generated_dir_back_slash}/module.g.excl.cpp module_g_excl_cpp_back_slash)
+
+ # using add_custom_command trick to prevent rerunning script unless ${file} is changed
+ add_custom_command(
+ OUTPUT ${header_filename} ${winmd_filename}
+ DEPENDS ${file}
+ COMMAND ${midl_exe}
+ /metadata_dir ${sdk_metadata_directory}
+ /W1 /char signed /nomidl /nologo /winrt
+ /no_settings_comment /no_def_idir /target "NT60"
+ /I ${um_sdk_directory}
+ /I ${shared_sdk_directory}
+ /I ${winrt_sdk_directory}
+ /I ${idl_source_directory}
+ /winmd ${winmd_filename}
+ /h ${header_filename}
+ /tlb ${tlb_filename}
+ ${idl_file_forward_slash}
+ COMMAND
+ ${cppwinrt_exe} -in ${winmd_filename} -comp ${output_dir_back_slash} -ref ${sdk_metadata_directory} -out ${generated_dir_back_slash} -verbose
+ COMMAND
+ # copy the generated component files into a temporary directory where headers exclusions will be applied
+ xcopy ${output_dir_back_slash} ${temp_dir_back_slash}\\ /Y /D
+ COMMAND
+ # for each file in the temp directory, ensure it is not in the exclusions list.
+ # if it is, then we need to delete it.
+ cmd /C "@echo off \
+ for /f %I in ('dir /b ${temp_dir_back_slash}') \
+ do \
+ ( \
+ for /f %E in (${CPPWINRT_COMPONENT_EXCLUSION_LIST}) \
+ do \
+ ( \
+ if %E == %I \
+ ( \
+ del ${temp_dir_back_slash}\\%I \
+ ) \
+ ) \
+ )"
+ COMMAND
+ # for each file in the temp directory, copy the file back into the source tree
+ # unless the file already exists
+ cmd /C "@echo off \
+ for /f %I in ('dir /b ${temp_dir_back_slash}') \
+ do \
+ ( \
+ if not exist ${out_sources_folder}\\%I \
+ ( \
+ copy ${temp_dir_back_slash}\\%I ${out_sources_folder}\\%I \
+ ) \
+ )"
+ COMMAND
+ # open the generated module.g.cpp and strip all the includes (lines) containing excluded headers
+ # write the new file out to module.g.excl.cpp.
+ powershell -Command "& { \
+ $exclusions = get-content '${CPPWINRT_COMPONENT_EXCLUSION_LIST}'; \
+ (get-content '${module_g_cpp_back_slash}') \
+ | where { \
+ $str = $_; \
+ $matches = ($exclusions | where { $str -match $_ }); \
+ $matches.Length -eq 0 } \
+ | Out-File '${module_g_excl_cpp_back_slash}' \
+ }"
+ BYPRODUCTS
+ ${generated_dir_back_slash}/module.g.excl.cpp
+ VERBATIM
+ )
+
+ add_custom_target(
+ ${target_name}
+ ALL
+ DEPENDS ${header_filename} ${winmd_filename}
+ )
+
+ set_target_properties(${target_name} PROPERTIES FOLDER ${folder_name})
+ endif()
+endfunction()
+
+function(add_generate_cppwinrt_sdk_headers_target
+ target_name # the name of the target to add
+ sdk_folder # sdk kit directory
+ sdk_version # sdk version
+ sdk_directory # the name of the folder to output the sdk headers to
+ folder_name # folder this target will be placed
+)
+ if (MSVC)
+ # get the current nuget sdk's metadata directory
+ get_sdk_metadata_folder(${sdk_folder} ${sdk_version} metadata_folder)
+
+ # get cppwinrt
+ get_sdk_cppwinrt_exe(${sdk_folder} ${sdk_version} cppwinrt_exe)
+
+ # windows.winmd is consumed by cppwinrt to produce the sdk headers
+ set(windows_winmd "${metadata_folder}/windows.winmd")
+
+ # base.h along with the other winrt sdk headers are produced by this command
+ set(base_h "${sdk_directory}/winrt/base.h")
+
+ # using add_custom_command trick to prevent rerunning script unless ${windows_winmd} is changed
+ add_custom_command(
+ OUTPUT ${base_h}
+ DEPENDS ${windows_winmd}
+ COMMAND ${cppwinrt_exe} -in \"${metadata_folder}\" -out \"${sdk_directory}\" -verbose
+ )
+
+ # add the target
+ add_custom_target(${target_name} ALL DEPENDS ${base_h})
+
+ set_target_properties(${target_name} PROPERTIES FOLDER ${folder_name})
+ endif()
+endfunction()
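+
+# Example usage (as in winml.cmake):
+#   add_generate_cppwinrt_sdk_headers_target(winml_sdk_cppwinrt ${sdk_folder} ${sdk_version}
+#     ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include ${target_folder})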
diff --git a/cmake/winml_sdk_helpers.cmake b/cmake/winml_sdk_helpers.cmake
new file mode 100644
index 0000000000000..9241fcd060caf
--- /dev/null
+++ b/cmake/winml_sdk_helpers.cmake
@@ -0,0 +1,120 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+cmake_minimum_required(VERSION 3.0)
+
+# utility
+function(convert_forward_slashes_to_back input output)
+ string(REGEX REPLACE "/" "\\\\" backwards ${input})
+ set(${output} ${backwards} PARENT_SCOPE)
+endfunction()
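+# e.g. convert_forward_slashes_to_back("a/b/c" out) sets out to "a\b\c"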
+
+# get the Windows 10 SDK install path from the registry
+function(get_installed_sdk
+ sdk_folder # the current sdk folder
+ output_sdk_version # the current sdk version
+)
+ # return the kit path
+ get_filename_component(win10_sdk_root "[HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Windows Kits\\Installed Roots;KitsRoot10]" ABSOLUTE CACHE)
+ set(${sdk_folder} ${win10_sdk_root} PARENT_SCOPE)
+
+ # return the sdk version
+ if(CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION)
+ set(${output_sdk_version} ${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION} PARENT_SCOPE)
+ else()
+ # choose the SDK matching the system version, or fallback to the latest
+ file(GLOB win10_sdks RELATIVE "${win10_sdk_root}/UnionMetadata" "${win10_sdk_root}/UnionMetadata/*.*.*.*")
+ list(GET win10_sdks 0 latest_sdk)
+ foreach(sdk IN LISTS win10_sdks)
+ string(FIND ${sdk} ${CMAKE_SYSTEM_VERSION} is_system_version)
+ if(NOT ${is_system_version} EQUAL -1)
+ set(${output_sdk_version} ${sdk} PARENT_SCOPE)
+ return()
+ elseif(sdk VERSION_GREATER latest_sdk)
+ set(latest_sdk ${sdk})
+ endif()
+ endforeach()
+ set(${output_sdk_version} ${latest_sdk} PARENT_SCOPE)
+ endif()
+endfunction()
+
+# current sdk binary directory
+function(get_sdk_binary_directory
+ sdk_folder # the kit path
+ sdk_version # the sdk version
+ binary_dir # the output folder variable
+)
+ set(${binary_dir} "${sdk_folder}/bin/${sdk_version}" PARENT_SCOPE)
+endfunction()
+
+# current sdk include directory
+function(get_sdk_include_folder
+ sdk_folder # the kit path
+ sdk_version # the sdk version
+ include_dir # the output folder variable
+)
+ set(${include_dir} "${sdk_folder}/include/${sdk_version}" PARENT_SCOPE)
+endfunction()
+
+# current sdk metadata directory
+function(get_sdk_metadata_folder
+ sdk_folder # the kit path
+ sdk_version # the sdk version
+ metadata_dir # the output folder variable
+)
+ set(${metadata_dir} "${sdk_folder}/UnionMetadata/${sdk_version}" PARENT_SCOPE)
+endfunction()
+
+# current sdk midl exe path
+function(get_sdk_midl_exe
+ sdk_folder # the kit path
+ sdk_version # the sdk version
+ midl_exe_path # the output exe path
+)
+ get_sdk_binary_directory(${sdk_folder} ${sdk_version} bin_dir)
+ set(${midl_exe_path} "${bin_dir}/x64/midlrt.exe" PARENT_SCOPE)
+endfunction()
+
+# installed sdk cppwinrt exe path
+function(get_installed_sdk_cppwinrt_exe
+ sdk_folder # the kit path
+ sdk_version # the sdk version
+ cppwinrt_exe_path # the output exe path
+)
+ get_sdk_binary_directory(${sdk_folder} ${sdk_version} bin_dir)
+ set(${cppwinrt_exe_path} "${bin_dir}/x64/cppwinrt.exe" PARENT_SCOPE)
+endfunction()
+
+# current cppwinrt exe path (honors winml_CPPWINRT_EXE_PATH_OVERRIDE)
+function(get_sdk_cppwinrt_exe
+ sdk_folder # the kit path
+ sdk_version # the sdk version
+ output_cppwinrt_exe_path # the output exe path
+)
+ if (NOT DEFINED winml_CPPWINRT_EXE_PATH_OVERRIDE)
+ get_installed_sdk_cppwinrt_exe(${sdk_folder} ${sdk_version} cppwinrt_exe_path)
+ set(${output_cppwinrt_exe_path} ${cppwinrt_exe_path} PARENT_SCOPE)
+ else ()
+ set(${output_cppwinrt_exe_path} ${winml_CPPWINRT_EXE_PATH_OVERRIDE} PARENT_SCOPE)
+ endif()
+endfunction()
+
+function(get_sdk
+ output_sdk_folder # the path to the current sdk kit folder
+ output_sdk_version # the current sdk version
+)
+ if ((NOT DEFINED winml_WINDOWS_SDK_DIR_OVERRIDE) AND
+ (NOT DEFINED winml_WINDOWS_SDK_VERSION_OVERRIDE))
+ get_installed_sdk(sdk_folder sdk_version)
+ set(${output_sdk_folder} ${sdk_folder} PARENT_SCOPE)
+ set(${output_sdk_version} ${sdk_version} PARENT_SCOPE)
+ elseif ((DEFINED winml_WINDOWS_SDK_DIR_OVERRIDE) AND
+ (DEFINED winml_WINDOWS_SDK_VERSION_OVERRIDE))
+ set(${output_sdk_folder} ${winml_WINDOWS_SDK_DIR_OVERRIDE} PARENT_SCOPE)
+ set(${output_sdk_version} ${winml_WINDOWS_SDK_VERSION_OVERRIDE} PARENT_SCOPE)
+ else()
+ message(
+ FATAL_ERROR
+ "Options winml_WINDOWS_SDK_DIR_OVERRIDE and winml_WINDOWS_SDK_VERSION_OVERRIDE must be defined together, or not at all.")
+ endif()
+endfunction()
diff --git a/cmake/winml_unittests.cmake b/cmake/winml_unittests.cmake
new file mode 100644
index 0000000000000..8e35f7e75bde8
--- /dev/null
+++ b/cmake/winml_unittests.cmake
@@ -0,0 +1,132 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+set(WINML_TEST_SRC_DIR ${REPO_ROOT}/winml/test)
+set(WINML_TEST_INC_DIR
+ ${REPO_ROOT}/winml/test/common
+ ${REPO_ROOT}/winml/lib/Common/inc
+ ${REPO_ROOT}/onnxruntime
+ ${REPO_ROOT}/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/D3DX12
+ ${REPO_ROOT}/cmake/external/googletest/googletest/include
+ ${REPO_ROOT}/cmake/external/protobuf/src
+ ${REPO_ROOT}/cmake/external/wil/include
+ ${CMAKE_CURRENT_BINARY_DIR}
+ ${CMAKE_CURRENT_BINARY_DIR}/winml_api
+ ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated
+ ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include)
+
+function(set_winml_target_properties target)
+ set_target_properties(${target} PROPERTIES
+ FOLDER "ONNXRuntimeTest/winml"
+ CXX_STANDARD 17
+ CXX_STANDARD_REQUIRED YES
+ CXX_EXTENSIONS NO
+ )
+ target_include_directories(${target} PRIVATE ${WINML_TEST_INC_DIR})
+endfunction()
+
+function(add_winml_test)
+ # Add a test target and make it discoverable by CTest by calling add_test
+ cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS" ${ARGN})
+ if(_UT_LIBS)
+ list(REMOVE_DUPLICATES _UT_LIBS)
+ endif()
+ list(REMOVE_DUPLICATES _UT_SOURCES)
+ if (_UT_DEPENDS)
+ list(REMOVE_DUPLICATES _UT_DEPENDS)
+ endif()
+
+ add_executable(${_UT_TARGET} ${_UT_SOURCES})
+ source_group(TREE ${WINML_TEST_SRC_DIR} FILES ${_UT_SOURCES})
+ set_winml_target_properties(${_UT_TARGET})
+
+ if (_UT_DEPENDS)
+ add_dependencies(${_UT_TARGET} ${_UT_DEPENDS})
+ endif()
+ target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} gtest winml_google_test_lib ${onnxruntime_EXTERNAL_LIBRARIES} winml_lib_common onnxruntime)
+
+ add_test(NAME ${_UT_TARGET}
+ COMMAND ${_UT_TARGET}
+    WORKING_DIRECTORY $<TARGET_FILE_DIR:${_UT_TARGET}>
+ )
+endfunction()
+
+function(get_winml_test_scenario_src
+ winml_test_src_path
+ output_winml_test_scenario_src
+ output_winml_test_scenario_libs
+)
+ if (onnxruntime_USE_DML)
+ file(GLOB winml_test_scenario_src CONFIGURE_DEPENDS "${winml_test_src_path}/scenario/cppwinrt/*.cpp")
+ set(${output_winml_test_scenario_libs} "onnxruntime_providers_dml" PARENT_SCOPE)
+ else()
+ set(winml_test_scenario_src "${winml_test_src_path}/scenario/cppwinrt/scenariotestscppwinrt.cpp")
+ endif()
+ set(${output_winml_test_scenario_src} ${winml_test_scenario_src} PARENT_SCOPE)
+endfunction()
+
+function(get_winml_test_api_src
+ winml_test_src_path
+ output_winml_test_api_src
+)
+ file(GLOB winml_test_api_src CONFIGURE_DEPENDS "${winml_test_src_path}/api/*.cpp")
+ set(${output_winml_test_api_src} ${winml_test_api_src} PARENT_SCOPE)
+endfunction()
+
+file(GLOB winml_test_common_src CONFIGURE_DEPENDS "${WINML_TEST_SRC_DIR}/common/*.cpp")
+add_library(winml_test_common STATIC ${winml_test_common_src})
+add_dependencies(winml_test_common
+ onnx
+ winml_api
+ winml_dll
+)
+
+add_library(winml_google_test_lib STATIC ${WINML_TEST_SRC_DIR}/common/googletest/main.cpp)
+set_winml_target_properties(winml_google_test_lib)
+
+set_winml_target_properties(winml_test_common)
+get_winml_test_api_src(${WINML_TEST_SRC_DIR} winml_test_api_src)
+add_winml_test(
+ TARGET winml_test_api
+ SOURCES ${winml_test_api_src}
+ LIBS winml_test_common
+)
+target_compile_definitions(winml_test_api PRIVATE BUILD_GOOGLE_TEST)
+target_precompiled_header(winml_test_api testPch.h)
+
+get_winml_test_scenario_src(${WINML_TEST_SRC_DIR} winml_test_scenario_src winml_test_scenario_libs)
+add_winml_test(
+ TARGET winml_test_scenario
+ SOURCES ${winml_test_scenario_src}
+ LIBS winml_test_common delayimp.lib ${winml_test_scenario_libs}
+)
+target_precompiled_header(winml_test_scenario testPch.h)
+target_compile_definitions(winml_test_scenario PRIVATE BUILD_GOOGLE_TEST)
+set_target_properties(winml_test_scenario PROPERTIES LINK_FLAGS
+ "/DELAYLOAD:d2d1.dll /DELAYLOAD:d3d11.dll /DELAYLOAD:dxgi.dll"
+)
+
+# During build time, copy any modified collaterals.
+# configure_file(source destination COPYONLY), which configures CMake to copy the file whenever source is modified,
+# can't be used here because we don't know the destination during configure time (in multi-configuration generators,
+# such as VS, one can switch between Debug/Release builds in the same build tree, and the destination depends on the
+# build mode).
+function(add_winml_collateral source)
+ get_filename_component(source_directory ${source} DIRECTORY)
+ file(GLOB_RECURSE collaterals RELATIVE ${source_directory} ${source})
+ foreach(collateral ${collaterals})
+ set(collateral_path ${source_directory}/${collateral})
+ if(NOT IS_DIRECTORY ${collateral_path})
+ add_custom_command(TARGET winml_test_common
+ POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${collateral_path} "$<TARGET_FILE_DIR:winml_test_common>/${collateral}")
+ endif()
+ endforeach()
+endfunction()
+
+add_winml_collateral("${WINML_TEST_SRC_DIR}/api/models/*.onnx")
+add_winml_collateral("${WINML_TEST_SRC_DIR}/collateral/images/*.png")
+add_winml_collateral("${WINML_TEST_SRC_DIR}/collateral/models/*.onnx")
+add_winml_collateral("${WINML_TEST_SRC_DIR}/common/testdata/squeezenet/*")
+add_winml_collateral("${WINML_TEST_SRC_DIR}/scenario/cppwinrt/*.onnx")
+add_winml_collateral("${WINML_TEST_SRC_DIR}/scenario/models/*.onnx")
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
index c4cc6a526da36..1a254099e2a43 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
@@ -55,6 +55,14 @@
Visible="false"
/>
+
+
+
+
diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h
--- a/include/onnxruntime/core/common/status.h
+++ b/include/onnxruntime/core/common/status.h
#include <memory>
#include <string>
+#ifdef _WIN32
+#include <winerror.h>
+#endif
namespace onnxruntime {
namespace common {
@@ -75,6 +78,40 @@ inline const char* StatusCodeToString(StatusCode status) noexcept {
}
}
+#ifdef _WIN32
+inline HRESULT StatusCodeToHRESULT(StatusCode status) noexcept {
+ switch (status)
+ {
+ case StatusCode::OK:
+ return S_OK;
+ case StatusCode::FAIL:
+ return E_FAIL;
+ case StatusCode::INVALID_ARGUMENT:
+ return E_INVALIDARG;
+ case StatusCode::NO_SUCHFILE:
+ return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
+ case StatusCode::NO_MODEL:
+ return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
+ case StatusCode::ENGINE_ERROR:
+ return E_FAIL;
+ case StatusCode::RUNTIME_EXCEPTION:
+ return E_FAIL;
+ case StatusCode::INVALID_PROTOBUF:
+ return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
+ case StatusCode::MODEL_LOADED:
+ return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
+ case StatusCode::NOT_IMPLEMENTED:
+ return E_NOTIMPL;
+ case StatusCode::INVALID_GRAPH:
+ return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
+ case StatusCode::EP_FAIL:
+ return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
+ default:
+ return E_FAIL;
+ }
+}
+#endif
+
class Status {
public:
Status() noexcept = default;
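A hedged sketch of how the new StatusCodeToHRESULT mapping might be used at a COM boundary; ToComResult is an illustrative helper, not something this change adds:

```cpp
// Illustrative only: translate an onnxruntime Status into an HRESULT using
// the StatusCodeToHRESULT mapping introduced above.
#include "core/common/status.h"

#ifdef _WIN32
HRESULT ToComResult(const onnxruntime::common::Status& status) {
  using onnxruntime::common::StatusCode;
  return onnxruntime::common::StatusCodeToHRESULT(static_cast<StatusCode>(status.Code()));
}
#endif
```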
diff --git a/include/onnxruntime/core/platform/windows/TraceLoggingConfig.h b/include/onnxruntime/core/platform/windows/TraceLoggingConfig.h
index d9aed52b87c40..77114ecdf8c0f 100644
--- a/include/onnxruntime/core/platform/windows/TraceLoggingConfig.h
+++ b/include/onnxruntime/core/platform/windows/TraceLoggingConfig.h
@@ -78,4 +78,4 @@ Module Name:
// TraceLoggingString(szUser, "UserName", "User's name", MICROSOFT_FIELDTAG_HASH_PII),
// ...);
#define MICROSOFT_FIELDTAG_DROP_PII 0x04000000
-#define MICROSOFT_FIELDTAG_HASH_PII 0x08000000
\ No newline at end of file
+#define MICROSOFT_FIELDTAG_HASH_PII 0x08000000
diff --git a/include/onnxruntime/core/platform/windows/readme.txt b/include/onnxruntime/core/platform/windows/readme.txt
new file mode 100644
index 0000000000000..f1a436fc200be
--- /dev/null
+++ b/include/onnxruntime/core/platform/windows/readme.txt
@@ -0,0 +1,2 @@
+copied from minkernel/published/internal/telemetry/open_source/TraceLoggingConfig.h
+this is the official open-source edition of these configuration settings
\ No newline at end of file
diff --git a/include/onnxruntime/core/providers/winml/winml_provider_factory.h b/include/onnxruntime/core/providers/winml/winml_provider_factory.h
new file mode 100644
index 0000000000000..b08b42e310e41
--- /dev/null
+++ b/include/onnxruntime/core/providers/winml/winml_provider_factory.h
@@ -0,0 +1,9 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "onnxruntime_c_api.h"
+
+struct WinmlAdapterApi;
+typedef struct WinmlAdapterApi WinmlAdapterApi;
+
+ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION;
\ No newline at end of file
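A minimal sketch of how a caller might acquire the adapter ABI, assuming the standard OrtGetApiBase entry point and that both headers are on the include path:

```cpp
// Sketch: fetch the versioned C API, then the WinML adapter extension surface.
#include "onnxruntime_c_api.h"
#include "winml_provider_factory.h"

const WinmlAdapterApi* GetWinmlAdapter() {
  const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  return OrtGetWinMLAdapter(ort_api);  // nullptr checks elided for brevity
}
```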
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 76b325a47169f..cca39a086280c 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -156,6 +156,8 @@ ORT_RUNTIME_CLASS(TypeInfo);
ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo);
ORT_RUNTIME_CLASS(SessionOptions);
ORT_RUNTIME_CLASS(CustomOpDomain);
+ORT_RUNTIME_CLASS(MapTypeInfo);
+ORT_RUNTIME_CLASS(SequenceTypeInfo);
// When passing in an allocator to any ORT function, be sure that the allocator object
// is not destroyed until the last allocated object using it is freed.
@@ -648,6 +650,66 @@ struct OrtApi {
ORT_CLASS_RELEASE(TensorTypeAndShapeInfo);
ORT_CLASS_RELEASE(SessionOptions);
ORT_CLASS_RELEASE(CustomOpDomain);
+
+ // End of Version 1 - DO NOT MODIFY ABOVE (see above text for more information)
+
+ // Version 2 - In development, feel free to add/remove/rearrange here
+
+ /**
+ * GetDenotationFromTypeInfo
+   * This API augments OrtTypeInfo to return the type's denotation.
+   * WinML uses this to determine whether an input/output is intended to be an Image or a Tensor.
+ */
+ OrtStatus*(ORT_API_CALL* GetDenotationFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len)NO_EXCEPTION;
+
+ // OrtTypeInfo Casting methods
+
+ /**
+ * CastTypeInfoToMapTypeInfo
+   * This API augments OrtTypeInfo to return an OrtMapTypeInfo when the type is a map.
+ * The OrtMapTypeInfo has additional information about the map's key type and value type.
+ * This is used by WinML to support model reflection APIs.
+ *
+ * Don't free the 'out' value
+ */
+ OrtStatus*(ORT_API_CALL* CastTypeInfoToMapTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtMapTypeInfo** out)NO_EXCEPTION;
+
+ /**
+ * CastTypeInfoToSequenceTypeInfo
+   * This API augments OrtTypeInfo to return an OrtSequenceTypeInfo when the type is a sequence.
+ * The OrtSequenceTypeInfo has additional information about the sequence's element type.
+ * This is used by WinML to support model reflection APIs.
+ *
+ * Don't free the 'out' value
+ */
+ OrtStatus*(ORT_API_CALL* CastTypeInfoToSequenceTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtSequenceTypeInfo** out)NO_EXCEPTION;
+
+ // OrtMapTypeInfo Accessors
+
+ /**
+ * GetMapKeyType
+   * This API returns the key type of a map. Key types are restricted to scalar types and are expressed as ONNXTensorElementDataType.
+ * This is used by WinML to support model reflection APIs.
+ */
+ OrtStatus*(ORT_API_CALL* GetMapKeyType)(_In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION;
+
+ /**
+ * GetMapValueType
+   * This API returns the value type of a map.
+ */
+ OrtStatus*(ORT_API_CALL* GetMapValueType)(_In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION;
+
+ // OrtSequenceTypeInfo Accessors
+
+ /**
+ * GetSequenceElementType
+   * This API returns the element type of a sequence.
+ * This is used by WinML to support model reflection APIs.
+ */
+ OrtStatus*(ORT_API_CALL* GetSequenceElementType)(_In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION;
+
+ ORT_CLASS_RELEASE(MapTypeInfo);
+ ORT_CLASS_RELEASE(SequenceTypeInfo);
};
/*
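An illustrative use of the Version-2 denotation query, given an OrtTypeInfo obtained from an API such as SessionGetInputTypeInfo; "IMAGE" is the standard ONNX type denotation WinML checks for, and error handling is elided:

```cpp
// 'api' is the OrtApi fetched via OrtGetApiBase()->GetApi(ORT_API_VERSION).
#include <cstring>
#include "onnxruntime_c_api.h"

bool HasImageDenotation(const OrtApi* api, const OrtTypeInfo* type_info) {
  const char* denotation = nullptr;
  size_t len = 0;
  api->GetDenotationFromTypeInfo(type_info, &denotation, &len);
  return len == 5 && std::strncmp(denotation, "IMAGE", len) == 0;
}
```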
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index eb9e9394ddd72..a97a5d413f904 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -354,4 +354,4 @@ struct CustomOpBase : OrtCustomOp {
} // namespace Ort
-#include "onnxruntime_cxx_inline.h"
+#include "onnxruntime_cxx_inline.h"
\ No newline at end of file
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 99c1656c0a7bc..f6fb350171f01 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -553,4 +553,4 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
return out;
}
-} // namespace Ort
+} // namespace Ort
\ No newline at end of file
diff --git a/onnxruntime/core/framework/allocatormgr.cc b/onnxruntime/core/framework/allocatormgr.cc
index 6f2ccd3f2b034..3dec82e3a3d2c 100644
--- a/onnxruntime/core/framework/allocatormgr.cc
+++ b/onnxruntime/core/framework/allocatormgr.cc
@@ -29,9 +29,4 @@ AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, OrtDevice::De
return AllocatorPtr(std::move(device_allocator));
}
-DeviceAllocatorRegistry& DeviceAllocatorRegistry::Instance() {
- static DeviceAllocatorRegistry s_instance;
- return s_instance;
-}
-
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/allocatormgr.h b/onnxruntime/core/framework/allocatormgr.h
index e6824dba8b79f..0ccc30b695cad 100644
--- a/onnxruntime/core/framework/allocatormgr.h
+++ b/onnxruntime/core/framework/allocatormgr.h
@@ -18,25 +18,4 @@ struct DeviceAllocatorRegistrationInfo {
AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, OrtDevice::DeviceId device_id = 0);
-class DeviceAllocatorRegistry {
- public:
- void RegisterDeviceAllocator(std::string&& name, DeviceAllocatorFactory factory, size_t max_mem,
- OrtMemType mem_type = OrtMemTypeDefault) {
- DeviceAllocatorRegistrationInfo info({mem_type, factory, max_mem});
- device_allocator_registrations_.emplace(std::move(name), std::move(info));
- }
-
- const std::map& AllRegistrations() const {
- return device_allocator_registrations_;
- }
-
- static DeviceAllocatorRegistry& Instance();
-
- private:
- DeviceAllocatorRegistry() = default;
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeviceAllocatorRegistry);
-
- std::map device_allocator_registrations_;
-};
-
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
new file mode 100644
index 0000000000000..107cdbbed10c2
--- /dev/null
+++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/framework/onnxruntime_map_type_info.h"
+#include "core/framework/onnxruntime_typeinfo.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/session/ort_apis.h"
+#include "core/framework/error_code_helper.h"
+
+OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type) noexcept : map_key_type_(map_key_type), map_value_type_(map_value_type, &OrtApis::ReleaseTypeInfo) {
+}
+
+static ONNXTensorElementDataType
+ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) {
+ using TensorType = ONNX_NAMESPACE::TensorProto_DataType;
+ switch (data_type) {
+ case TensorType::TensorProto_DataType_BOOL: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; }
+ case TensorType::TensorProto_DataType_STRING: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; } // maps to c++ type std::string
+    case TensorType::TensorProto_DataType_FLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; } // IEEE 754 half-precision float
+    case TensorType::TensorProto_DataType_FLOAT: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } // maps to c type float
+ case TensorType::TensorProto_DataType_DOUBLE: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; } // maps to c type double
+ case TensorType::TensorProto_DataType_INT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; } // maps to c type int8_t
+ case TensorType::TensorProto_DataType_INT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; } // maps to c type int16_t
+ case TensorType::TensorProto_DataType_INT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } // maps to c type int32_t
+ case TensorType::TensorProto_DataType_INT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; } // maps to c type int64_t
+ case TensorType::TensorProto_DataType_UINT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; } // maps to c type uint8_t
+ case TensorType::TensorProto_DataType_UINT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; } // maps to c type uint16_t
+ case TensorType::TensorProto_DataType_UINT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; } // maps to c type uint32_t
+ case TensorType::TensorProto_DataType_UINT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; } // maps to c type uint64_t
+ case TensorType::TensorProto_DataType_COMPLEX64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; } // complex with float32 real and imaginary components
+ case TensorType::TensorProto_DataType_COMPLEX128: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; } // complex with float64 real and imaginary components
+ case TensorType::TensorProto_DataType_BFLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; } // Non-IEEE floating-point format based on IEEE754 single-precision
+ default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }
+ }
+}
+
+OrtStatus* OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtMapTypeInfo** out) {
+ auto value_case = type_proto->value_case();
+ if (value_case != ONNX_NAMESPACE::TypeProto::kMapType)
+ {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type map!");;
+ }
+
+ // Get the key type of the map
+ auto type_proto_map = type_proto->map_type();
+ auto map_key_type = ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type()));
+
+ // Get the value type of the map
+ OrtTypeInfo* map_value_type_info = nullptr;
+ if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_map.value_type(), &map_value_type_info))
+ {
+ return status;
+ }
+
+ *out = new OrtMapTypeInfo(map_key_type, map_value_type_info);
+ return nullptr;
+}
+
+OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) {
+ OrtTypeInfo* map_value_type_copy = nullptr;
+ if (auto status = map_value_type_->Clone(&map_value_type_copy))
+ {
+ return status;
+ }
+ *out = new OrtMapTypeInfo(map_key_type_, map_value_type_copy);
+ return nullptr;
+}
+
+// OrtMapTypeInfo Accessors
+ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, const OrtMapTypeInfo* map_type_info, enum ONNXTensorElementDataType* out) {
+ API_IMPL_BEGIN
+ *out = map_type_info->map_key_type_;
+ return nullptr;
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, const OrtMapTypeInfo* map_type_info, OrtTypeInfo** out) {
+ API_IMPL_BEGIN
+ return map_type_info->map_value_type_->Clone(out);
+ API_IMPL_END
+}
+
+ORT_API(void, OrtApis::ReleaseMapTypeInfo, OrtMapTypeInfo* ptr) {
+ delete ptr;
+}
\ No newline at end of file
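Taken together with the C API additions earlier, a hedged sketch of reflecting over a map-typed value such as map&lt;int64, float&gt;; the cast result is borrowed while the value type is a clone the caller releases:

```cpp
// 'api' is the OrtApi. The OrtMapTypeInfo is borrowed from type_info and must
// not be freed; the OrtTypeInfo from GetMapValueType is a clone the caller owns.
#include "onnxruntime_c_api.h"

void InspectMap(const OrtApi* api, const OrtTypeInfo* type_info) {
  const OrtMapTypeInfo* map_info = nullptr;
  api->CastTypeInfoToMapTypeInfo(type_info, &map_info);
  if (map_info == nullptr) return;  // not a map

  ONNXTensorElementDataType key_type;
  api->GetMapKeyType(map_info, &key_type);  // e.g. ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64

  OrtTypeInfo* value_type = nullptr;
  api->GetMapValueType(map_info, &value_type);
  api->ReleaseTypeInfo(value_type);  // caller owns the clone
}
```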
diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.h b/onnxruntime/core/framework/onnxruntime_map_type_info.h
new file mode 100644
index 0000000000000..46477d8f04fa7
--- /dev/null
+++ b/onnxruntime/core/framework/onnxruntime_map_type_info.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include "onnxruntime_c_api.h"
+
+#include <memory>
+
+namespace ONNX_NAMESPACE {
+class TypeProto;
+}
+
+struct OrtMapTypeInfo {
+ public:
+ ONNXTensorElementDataType map_key_type_ = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  std::unique_ptr<OrtTypeInfo, decltype(OrtApis::ReleaseTypeInfo)*> map_value_type_;
+
+ static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtMapTypeInfo** out);
+
+ OrtStatus* Clone(OrtMapTypeInfo** out);
+
+ private:
+  OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type) noexcept;
+ OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete;
+ OrtMapTypeInfo& operator=(const OrtMapTypeInfo& other) = delete;
+
+};
diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc
new file mode 100644
index 0000000000000..a5ee0c9a63bb1
--- /dev/null
+++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc
@@ -0,0 +1,49 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/framework/onnxruntime_sequence_type_info.h"
+#include "core/framework/onnxruntime_typeinfo.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/session/ort_apis.h"
+#include "core/framework/error_code_helper.h"
+
+OrtSequenceTypeInfo::OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept :
+ sequence_key_type_(sequence_key_type, &OrtApis::ReleaseTypeInfo) {
+}
+
+OrtStatus* OrtSequenceTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtSequenceTypeInfo** out) {
+ auto value_case = type_proto->value_case();
+ if (value_case != ONNX_NAMESPACE::TypeProto::kSequenceType)
+ {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type sequence!");;
+ }
+
+ auto type_proto_sequence = type_proto->sequence_type();
+ OrtTypeInfo* sequence_key_type_info = nullptr;
+ if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_sequence.elem_type(), &sequence_key_type_info))
+ {
+ return status;
+ }
+
+ *out = new OrtSequenceTypeInfo(sequence_key_type_info);
+ return nullptr;
+}
+
+OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) {
+ OrtTypeInfo* sequence_key_type_copy = nullptr;
+ if (auto status = sequence_key_type_->Clone(&sequence_key_type_copy))
+ {
+ return status;
+ }
+ *out = new OrtSequenceTypeInfo(sequence_key_type_copy);
+ return nullptr;
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, const OrtSequenceTypeInfo* sequence_type_info, OrtTypeInfo** out) {
+ API_IMPL_BEGIN
+ return sequence_type_info->sequence_key_type_->Clone(out);
+ API_IMPL_END
+}
+
+ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, OrtSequenceTypeInfo* ptr) {
+ delete ptr;
+}
\ No newline at end of file
diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h
new file mode 100644
index 0000000000000..abb3503778b71
--- /dev/null
+++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include "onnxruntime_c_api.h"
+
+#include <memory>
+
+namespace ONNX_NAMESPACE {
+class TypeProto;
+}
+
+struct OrtSequenceTypeInfo {
+ public:
+  std::unique_ptr<OrtTypeInfo, decltype(OrtApis::ReleaseTypeInfo)*> sequence_key_type_;
+
+ OrtStatus* Clone(OrtSequenceTypeInfo** out);
+
+ static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtSequenceTypeInfo** out);
+
+ private:
+ OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept;
+ OrtSequenceTypeInfo(const OrtSequenceTypeInfo& other) = delete;
+ OrtSequenceTypeInfo& operator=(const OrtSequenceTypeInfo& other) = delete;
+};
diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc
index 080a3518048cb..42e03e802caf1 100644
--- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc
+++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc
@@ -10,6 +10,11 @@
#include "core/framework/sparse_tensor.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/ort_apis.h"
+#include "core/framework/error_code_helper.h"
+
+#include "core/framework/tensor_type_and_shape.h"
+#include "core/framework/onnxruntime_map_type_info.h"
+#include "core/framework/onnxruntime_sequence_type_info.h"
using onnxruntime::BFloat16;
using onnxruntime::DataTypeImpl;
@@ -20,11 +25,27 @@ using onnxruntime::TensorShape;
namespace on = ONNX_NAMESPACE;
+OrtTypeInfo::OrtTypeInfo(ONNXType type1) noexcept : type(type1) {
+}
+
OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) {
}
+OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtMapTypeInfo* map_type_info1) noexcept : type(type1), map_type_info(map_type_info1) {
+}
+
+OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtSequenceTypeInfo* sequence_type_info1) noexcept : type(type1), sequence_type_info(sequence_type_info1) {
+}
+
OrtTypeInfo::~OrtTypeInfo() {
OrtApis::ReleaseTensorTypeAndShapeInfo(data);
+
+ if (map_type_info) {
+ OrtApis::ReleaseMapTypeInfo(map_type_info);
+ }
+ if (sequence_type_info) {
+ OrtApis::ReleaseSequenceTypeInfo(sequence_type_info);
+ }
}
ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input, ONNXType* out) {
@@ -37,6 +58,28 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtType
return nullptr;
}
+ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, const OrtTypeInfo* type_info, const OrtMapTypeInfo** out) {
+ API_IMPL_BEGIN
+ *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info : nullptr;
+ return nullptr;
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, const OrtTypeInfo* type_info, const OrtSequenceTypeInfo** out) {
+ API_IMPL_BEGIN
+ *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info : nullptr;
+ return nullptr;
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, const OrtTypeInfo* type_info, const char** const out, size_t* len) {
+ API_IMPL_BEGIN
+ *out = type_info->denotation.c_str();
+ *len = type_info->denotation.size();
+ return nullptr;
+ API_IMPL_END
+}
+
ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) {
delete ptr;
}
@@ -49,7 +92,7 @@ OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vector<std::string>* dim_params,
   if (type->IsTensorSequenceType()) {
- *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr);
+ *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE);
return nullptr;
}
@@ -92,16 +135,14 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) {
// Place Opaque first as tensors will be mostly handled above and maps and sequences are not common
switch (type_proto->value_case()) {
case on::TypeProto::kOpaqueType: {
- *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr);
+ *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE);
return nullptr;
}
case on::TypeProto::kMapType: {
- *out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr);
- return nullptr;
+ return OrtTypeInfo::FromTypeProto(type_proto, out);
}
case on::TypeProto::kSequenceType: {
- *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr);
- return nullptr;
+ return OrtTypeInfo::FromTypeProto(type_proto, out);
}
// Real Tensor support
case on::TypeProto::kTensorType:
@@ -204,19 +245,39 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or
st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info);
}
if (st != nullptr) return st;
- *out = new OrtTypeInfo(ten_type, info);
+ auto type_info = new OrtTypeInfo(ten_type, info);
+ type_info->denotation = input->denotation();
+ *out = type_info;
return nullptr;
} break;
case on::TypeProto::kSequenceType: {
- *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr);
+ OrtSequenceTypeInfo* sequence_type_info = nullptr;
+
+ if (auto status = OrtSequenceTypeInfo::FromTypeProto(input, &sequence_type_info)) {
+ return status;
+ }
+
+ auto type_info = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, sequence_type_info);
+ type_info->denotation = input->denotation();
+ *out = type_info;
return nullptr;
} break;
case on::TypeProto::kMapType: {
- *out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr);
+ OrtMapTypeInfo* map_type_info = nullptr;
+
+ if (auto status = OrtMapTypeInfo::FromTypeProto(input, &map_type_info)) {
+ return status;
+ }
+
+ auto type_info = new OrtTypeInfo(ONNX_TYPE_MAP, map_type_info);
+ type_info->denotation = input->denotation();
+ *out = type_info;
return nullptr;
} break;
case on::TypeProto::kOpaqueType: {
- *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr);
+ auto type_info = new OrtTypeInfo(ONNX_TYPE_OPAQUE);
+ type_info->denotation = input->denotation();
+ *out = type_info;
return nullptr;
} break;
case on::TypeProto::VALUE_NOT_SET:
@@ -227,3 +288,48 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or
}
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented");
}
+
+OrtStatus* OrtTypeInfo::Clone(OrtTypeInfo** out) {
+ switch (type) {
+ case ONNX_TYPE_TENSOR:
+ case ONNX_TYPE_SPARSETENSOR:
+ {
+ OrtTensorTypeAndShapeInfo* clone;
+ if (auto status = data->Clone(&clone)) {
+ return status;
+ }
+ *out = new OrtTypeInfo(type, clone);
+ (*out)->denotation = denotation;
+ return nullptr;
+ }
+ case ONNX_TYPE_SEQUENCE:
+ {
+ OrtSequenceTypeInfo* clone;
+ if (auto status = sequence_type_info->Clone(&clone)) {
+ return status;
+ }
+ *out = new OrtTypeInfo(type, clone);
+ (*out)->denotation = denotation;
+ return nullptr;
+ }
+ case ONNX_TYPE_MAP: {
+ OrtMapTypeInfo* clone;
+ if (auto status = map_type_info->Clone(&clone)) {
+ return status;
+ }
+ *out = new OrtTypeInfo(type, clone);
+ (*out)->denotation = denotation;
+ return nullptr;
+ }
+ case ONNX_TYPE_OPAQUE:
+ {
+ *out = new OrtTypeInfo(type);
+ (*out)->denotation = denotation;
+ return nullptr;
+ }
+ default:
+ // Not implemented
+ break;
+ }
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented");
+}
\ No newline at end of file
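The Clone plumbing above is what lets the map and sequence accessors hand out independently owned copies. A sketch of the resulting lifetime contract:

```cpp
// The element OrtTypeInfo returned through GetSequenceElementType is a deep
// copy, so it remains valid after the parent OrtTypeInfo is released.
#include "onnxruntime_c_api.h"

void ElementOutlivesParent(const OrtApi* api, OrtTypeInfo* parent) {
  const OrtSequenceTypeInfo* seq = nullptr;
  api->CastTypeInfoToSequenceTypeInfo(parent, &seq);
  OrtTypeInfo* element = nullptr;
  if (seq != nullptr) api->GetSequenceElementType(seq, &element);
  api->ReleaseTypeInfo(parent);                            // parent released first
  if (element != nullptr) api->ReleaseTypeInfo(element);   // clone still valid until here
}
```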
diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h
index d615840dcb501..3c256aa73d17d 100644
--- a/onnxruntime/core/framework/onnxruntime_typeinfo.h
+++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h
@@ -3,6 +3,7 @@
#pragma once
#include
+#include <string>
#include "core/session/onnxruntime_c_api.h"
namespace onnxruntime {
@@ -14,6 +15,10 @@ namespace ONNX_NAMESPACE {
class TypeProto;
}
+// These types are only present in the winml adapter c api, so they are forward declared.
+struct OrtMapTypeInfo;
+struct OrtSequenceTypeInfo;
+
/**
* the equivalent of ONNX_NAMESPACE::TypeProto
* This class is mainly for the C API
@@ -21,19 +26,26 @@ class TypeProto;
struct OrtTypeInfo {
public:
ONNXType type = ONNX_TYPE_UNKNOWN;
+ std::string denotation;
~OrtTypeInfo();
//owned by this
OrtTensorTypeAndShapeInfo* data = nullptr;
+ OrtMapTypeInfo* map_type_info = nullptr;
+ OrtSequenceTypeInfo* sequence_type_info = nullptr;
OrtTypeInfo(const OrtTypeInfo& other) = delete;
OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete;
+ OrtStatus* Clone(OrtTypeInfo** out);
+
static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out);
static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out);
-
static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type);
private:
+ OrtTypeInfo(ONNXType type) noexcept;
OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept;
+ OrtTypeInfo(ONNXType type, OrtMapTypeInfo* map_type_info) noexcept;
+ OrtTypeInfo(ONNXType type, OrtSequenceTypeInfo* sequence_type_info) noexcept;
};
diff --git a/onnxruntime/core/framework/path_lib.cc b/onnxruntime/core/framework/path_lib.cc
index f2e526424c808..f34deb9025c7a 100644
--- a/onnxruntime/core/framework/path_lib.cc
+++ b/onnxruntime/core/framework/path_lib.cc
@@ -7,8 +7,11 @@
#include
#ifdef _WIN32
+#if defined(USE_PATHCCH_LIB)
+#include <PathCch.h>
+#pragma comment(lib, "PathCch.lib")
// Desktop apps need to support back to Windows 7, so we can't use PathCch.lib as it was added in Windows 8
-#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
+#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
#include <shlwapi.h>
#pragma comment(lib, "Shlwapi.lib")
#else
@@ -24,7 +27,7 @@ namespace onnxruntime {
namespace {
Status RemoveFileSpec(PWSTR pszPath, size_t cchPath) {
assert(pszPath != nullptr && pszPath[0] != L'\0');
-#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
+#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) && !defined(USE_PATHCCH_LIB)
(void)cchPath;
for (PWSTR t = L"\0"; *t == L'\0'; t = PathRemoveBackslashW(pszPath))
;
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc
index 088043a159962..64bb11dbcbcfa 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.cc
+++ b/onnxruntime/core/framework/tensor_type_and_shape.cc
@@ -192,6 +192,11 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const st
return GetTensorShapeAndTypeHelper(type, shape, dim_params, out);
}
+OrtStatus* OrtTensorTypeAndShapeInfo::Clone(OrtTensorTypeAndShapeInfo** out)
+{
+ return GetTensorShapeAndTypeHelper(type, shape, &dim_params, out);
+}
+
ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out_ OrtTensorTypeAndShapeInfo** out) {
API_IMPL_BEGIN
onnxruntime::MLDataType type = v->Type();
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h
index 28431a9d614cf..f781160cc6505 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.h
+++ b/onnxruntime/core/framework/tensor_type_and_shape.h
@@ -13,4 +13,6 @@ struct OrtTensorTypeAndShapeInfo {
OrtTensorTypeAndShapeInfo() = default;
OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete;
OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete;
+
+ OrtStatus* Clone(OrtTensorTypeAndShapeInfo** out);
};
diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc
index 44bcc577f3fd0..d5edeb7cb4b2f 100644
--- a/onnxruntime/core/graph/function.cc
+++ b/onnxruntime/core/graph/function.cc
@@ -378,7 +378,7 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
auto status = function_body_graph.Resolve();
ORT_ENFORCE(status.IsOK(), "Resolve subgraph failed:", status.ErrorMessage());
-}
+} // namespace onnxruntime
FunctionImpl::~FunctionImpl() = default;
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index d8066bafba1aa..3cc756a12a71f 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -95,6 +95,10 @@ Model::Model(std::unique_ptr model_proto, const IOnnxRuntimeOpSchema
" specifies which version of the ONNX OperatorSet is being imported.");
}
+ if (!model_proto->has_ir_version() || model_proto->ir_version() > ONNX_NAMESPACE::Version::IR_VERSION) {
+ throw std::invalid_argument("Unknown model file format version.");
+ }
+
model_proto_ = std::move(model_proto);
for (auto& prop : model_proto_->metadata_props()) {
model_metadata_[prop.key()] = prop.value();
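Restating the new guard: a model must declare an IR version no newer than the one the runtime was compiled against, or loading fails early. A standalone sketch of the predicate:

```cpp
// Standalone restatement of the IR-version guard added above; passing this
// predicate is a precondition for continuing to load the model.
#include <cstdint>

bool IsLoadableIrVersion(bool has_ir_version, int64_t ir_version,
                         int64_t max_supported_ir_version) {
  return has_ir_version && ir_version <= max_supported_ir_version;
}
```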
diff --git a/onnxruntime/core/graph/schema_registry.cc b/onnxruntime/core/graph/schema_registry.cc
index f0d4005c1503d..8127af7877c23 100644
--- a/onnxruntime/core/graph/schema_registry.cc
+++ b/onnxruntime/core/graph/schema_registry.cc
@@ -194,6 +194,17 @@ DomainToVersionMap SchemaRegistryManager::GetLatestOpsetVersions(bool is_onnx_on
return domain_version_map;
}
+static bool IsDomainVersionBeyondSupportedRange(
+ const std::string& domain,
+ const int op_set_version) {
+ // check the ONNX schema registry
+ auto& onnx_domain_version_map =
+ ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().Map();
+
+ auto it = onnx_domain_version_map.find(domain);
+ return it != onnx_domain_version_map.end() && op_set_version > it->second.second;
+}
+
// Return the schema with biggest version, which is not greater than specified
// in specified domain. The value of earliest_opset_where_unchanged
// is also set to the earliest version preceding op_set_version where the operator
@@ -238,10 +249,14 @@ void SchemaRegistryManager::GetSchemaAndHistory(
checked_registry_indices.push_back(index);
}
- // if not found in registered custom schema registry, search in ONNX schema registry
- *latest_schema = ONNX_NAMESPACE::OpSchemaRegistry::Schema(key, version, domain);
- if (*latest_schema != nullptr) {
- *earliest_opset_where_unchanged = (*latest_schema)->SinceVersion();
+ // Reject versions greater than what is actually supported.
+ *latest_schema = nullptr;
+ if (!IsDomainVersionBeyondSupportedRange(domain, version)) {
+ // if not found in registered custom schema registry, search in ONNX schema registry
+ *latest_schema = ONNX_NAMESPACE::OpSchemaRegistry::Schema(key, version, domain);
+ if (*latest_schema != nullptr) {
+ *earliest_opset_where_unchanged = (*latest_schema)->SinceVersion();
+ }
}
}
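The effect of the new check, restated over illustrative data (the real map comes from the ONNX schema registry):

```cpp
// Standalone restatement of IsDomainVersionBeyondSupportedRange: each domain
// maps to an inclusive [earliest, latest] opset pair, and any requested
// version past 'latest' is rejected. The sample data below is illustrative.
#include <map>
#include <string>
#include <utility>

static bool BeyondSupportedRange(const std::map<std::string, std::pair<int, int>>& ranges,
                                 const std::string& domain, int op_set_version) {
  auto it = ranges.find(domain);
  return it != ranges.end() && op_set_version > it->second.second;
}

// BeyondSupportedRange({{"", {1, 11}}}, "", 13) == true: the schema lookup is
// skipped and *latest_schema stays nullptr rather than resolving a newer opset.
```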
diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc
index 7c587a1b5d469..e6092693c6661 100644
--- a/onnxruntime/core/platform/telemetry.cc
+++ b/onnxruntime/core/platform/telemetry.cc
@@ -22,12 +22,22 @@ void Telemetry::DisableTelemetryEvents() const {
void Telemetry::LogProcessInfo() const {
}
+void Telemetry::LogSessionCreationStart() const {
+}
+
+void Telemetry::LogEvaluationStop() const {
+}
+
+void Telemetry::LogEvaluationStart() const {
+}
+
void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map& model_metadata,
- const std::string& loadedFrom, const std::vector& execution_provider_ids) const {
+ const std::string& loadedFrom, const std::vector& execution_provider_ids,
+ bool use_fp16) const {
ORT_UNUSED_PARAMETER(session_id);
ORT_UNUSED_PARAMETER(ir_version);
ORT_UNUSED_PARAMETER(model_producer_name);
@@ -38,6 +48,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons
ORT_UNUSED_PARAMETER(model_metadata);
ORT_UNUSED_PARAMETER(loadedFrom);
ORT_UNUSED_PARAMETER(execution_provider_ids);
+ ORT_UNUSED_PARAMETER(use_fp16);
}
void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
@@ -55,5 +66,9 @@ void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_la
ORT_UNUSED_PARAMETER(total_run_duration_since_last);
}
+void Telemetry::LogExecutionProviderEvent(LUID* adapterLuid) const {
+ ORT_UNUSED_PARAMETER(adapterLuid);
+}
+
} // namespace onnxruntime
diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h
index a0fc42e045c49..a669c95eebd4a 100644
--- a/onnxruntime/core/platform/telemetry.h
+++ b/onnxruntime/core/platform/telemetry.h
@@ -10,6 +10,9 @@
#include "core/common/status.h"
#include "core/common/common.h"
+struct _LUID;
+using LUID = _LUID;
+
namespace onnxruntime {
/**
@@ -36,18 +39,27 @@ class Telemetry {
virtual void LogProcessInfo() const;
+ virtual void LogSessionCreationStart() const;
+
+ virtual void LogEvaluationStop() const;
+
+ virtual void LogEvaluationStart() const;
+
virtual void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map& model_metadata,
- const std::string& loadedFrom, const std::vector& execution_provider_ids) const;
+ const std::string& loadedFrom, const std::vector& execution_provider_ids,
+ bool use_fp16) const;
virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const;
virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const;
+ virtual void LogExecutionProviderEvent(LUID* adapterLuid) const;
+
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry);
};
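The base-class hooks above are deliberate no-ops; platforms opt in by overriding them, as WindowsTelemetry does below. A hedged sketch of a custom override, assuming the base class is default-constructible as the header suggests:

```cpp
// Illustrative Telemetry subclass exercising the new lifecycle hooks.
#include <iostream>
#include "core/platform/telemetry.h"

class LoggingTelemetry : public onnxruntime::Telemetry {
 public:
  void LogSessionCreationStart() const override { std::cout << "session creation start\n"; }
  void LogEvaluationStart() const override { std::cout << "evaluation start\n"; }
  void LogEvaluationStop() const override { std::cout << "evaluation stop\n"; }
};
```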
diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc
index 46461f3d052ae..8c0beb13978d1 100644
--- a/onnxruntime/core/platform/windows/debug_alloc.cc
+++ b/onnxruntime/core/platform/windows/debug_alloc.cc
@@ -235,12 +235,8 @@ Memory_LeakCheck::~Memory_LeakCheck() {
_snprintf_s(buffer, _TRUNCATE, "%d bytes of memory leaked in %d allocations", leaked_bytes, leak_count);
string.append(buffer);
- // If we're being actively debugged, show a message box to get the dev's attention
- if (IsDebuggerPresent())
- MessageBoxA(nullptr, string.c_str(), "Warning", MB_OK | MB_ICONWARNING);
- else {
- // If we're on the command line (like on a build machine), output to the console and exit(-1)
- std::cout << "\n----- MEMORY LEAKS: " << string.c_str() << "\n";
+ std::cout << "\n----- MEMORY LEAKS: " << string.c_str() << "\n";
+ if (!IsDebuggerPresent()) {
exit(-1);
}
diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc
index c7a8bda42d1c2..23092bbf21880 100644
--- a/onnxruntime/core/platform/windows/telemetry.cc
+++ b/onnxruntime/core/platform/windows/telemetry.cc
@@ -96,12 +96,37 @@ void WindowsTelemetry::LogProcessInfo() const {
process_info_logged = true;
}
+void WindowsTelemetry::LogSessionCreationStart() const {
+ TraceLoggingWrite(telemetry_provider_handle,
+ "SessionCreationStart",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
+}
+
+void WindowsTelemetry::LogEvaluationStop() const {
+ TraceLoggingWrite(telemetry_provider_handle,
+ "EvaluationStop",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
+}
+
+void WindowsTelemetry::LogEvaluationStart() const {
+ TraceLoggingWrite(telemetry_provider_handle,
+ "EvaluationStart",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
+}
+
void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map& model_metadata,
- const std::string& loadedFrom, const std::vector& execution_provider_ids) const {
+ const std::string& loaded_from, const std::vector& execution_provider_ids,
+ bool use_fp16) const {
if (global_register_count_ == 0 || enabled_ == false)
return;
@@ -156,10 +181,11 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
TraceLoggingString(model_producer_name.c_str(), "modelProducerName"),
TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"),
TraceLoggingString(model_domain.c_str(), "modelDomain"),
+ TraceLoggingBool(use_fp16, "usefp16"),
TraceLoggingString(domain_to_verison_string.c_str(), "domainToVersionMap"),
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
- TraceLoggingString(loadedFrom.c_str(), "loadedFrom"),
+ TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
}
@@ -170,6 +196,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
TraceLoggingWrite(telemetry_provider_handle,
"RuntimeError",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// Telemetry info
@@ -198,4 +225,17 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s
TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"));
}
+void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const {
+ if (global_register_count_ == 0 || enabled_ == false)
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "ExecutionProviderEvent",
+ TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ // Telemetry info
+ TraceLoggingUInt32(adapterLuid->LowPart, "adapterLuidLowPart"),
+ TraceLoggingUInt32(adapterLuid->HighPart, "adapterLuidHighPart"));
+}
+
} // namespace onnxruntime
diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h
index dd5da6205bcae..d34b26860308e 100644
--- a/onnxruntime/core/platform/windows/telemetry.h
+++ b/onnxruntime/core/platform/windows/telemetry.h
@@ -13,9 +13,7 @@ namespace onnxruntime {
* derives and implements a Telemetry provider on Windows
*/
class WindowsTelemetry : public Telemetry {
-
public:
-
// these are allowed to be created, WindowsEnv will create one
WindowsTelemetry();
~WindowsTelemetry();
@@ -25,18 +23,27 @@ class WindowsTelemetry : public Telemetry {
void LogProcessInfo() const override;
+ void LogSessionCreationStart() const override;
+
+ void LogEvaluationStop() const override;
+
+ void LogEvaluationStart() const override;
+
void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map& model_metadata,
- const std::string& loadedFrom, const std::vector& execution_provider_ids) const override;
-
+ const std::string& loadedFrom, const std::vector& execution_provider_ids,
+ bool use_fp16) const override;
+
void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const override;
void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override;
+ void LogExecutionProviderEvent(LUID* adapterLuid) const override;
+
private:
static OrtMutex mutex_;
static uint32_t global_register_count_;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
index c34dfbc2d93d6..0415976b58263 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
@@ -19,7 +19,7 @@ namespace onnxruntime
class Node;
}
-namespace winrt::Windows::AI::MachineLearning::implementation
+namespace Windows::AI::MachineLearning::Adapter
{
interface __declspec(uuid("5b19a18a-5ed5-4df2-a363-21b89380a698"))
IWinmlExecutionProvider : public IUnknown
@@ -65,6 +65,8 @@ namespace winrt::Windows::AI::MachineLearning::implementation
virtual void Close() = 0;
};
+    using MLOperatorTensorGetter = std::function<Microsoft::WRL::ComPtr<IMLOperatorTensor>(uint32_t index)>;
+
struct DmlOperatorParams
{
Microsoft::WRL::ComPtr op;
@@ -86,8 +88,6 @@ namespace winrt::Windows::AI::MachineLearning::implementation
bool allowHalfPrecisionComputation = false;
};
-    using MLOperatorTensorGetter = std::function<Microsoft::WRL::ComPtr<IMLOperatorTensor>(uint32_t index)>;
-
using GraphNodeFactory = std::function;
}
-namespace winrt::Windows::AI::MachineLearning::implementation
+namespace Windows::AI::MachineLearning::Adapter
{
using namespace Microsoft::WRL;
@@ -110,4 +110,4 @@ class AbiCustomRegistry : public WRL::Base(_status.Code()))); \
} \
} while (0)
-
-namespace Dml
-{
- HRESULT MapLotusErrorToHRESULT(onnxruntime::common::Status status);
-}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 1c7c88c93aeff..84b298ca77dfc 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -29,7 +29,7 @@
#define ENABLE_GRAPH_COMPILATION
-using namespace winrt::Windows::AI::MachineLearning::implementation;
+using namespace Windows::AI::MachineLearning::Adapter;
namespace Dml
{
@@ -129,6 +129,12 @@ namespace Dml
}
}
+// The ORT release pipeline agent pools do not have the 19H1 SDK installed, which defines D3D_FEATURE_LEVEL_1_0_CORE.
+// Once the ORT/WinML github project can be built with VS2019, we can update these pools to install the 19H1 SDK
+// using the VS2019 command line installer tool.
+// Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap
+#define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000)
+
ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands)
: m_d3d12Device(d3d12Device),
m_dmlDevice(dmlDevice),
@@ -138,7 +144,7 @@ namespace Dml
D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {};
D3D_FEATURE_LEVEL featureLevelsList[] = {
- D3D_FEATURE_LEVEL_1_0_CORE,
+ D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE,
D3D_FEATURE_LEVEL_11_0,
D3D_FEATURE_LEVEL_11_1,
D3D_FEATURE_LEVEL_12_0,
@@ -153,7 +159,7 @@ namespace Dml
sizeof(featureLevels)
));
- m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE);
+ m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE);
m_context = std::make_shared(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
@@ -674,7 +680,7 @@ namespace Dml
return m_areMetacommandsEnabled;
}
-    std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap>
+    std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap>
ExecutionProviderImpl::GetInternalRegistrationInfoMap() const
{
return m_internalRegInfoMap;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
index fd1709d8299b5..58f73f62bad3c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
@@ -9,13 +9,11 @@
#include
#include
-namespace WRL
-{
-    template <typename... TInterfaces>
-    using Base = Microsoft::WRL::RuntimeClass<
-        Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
-        TInterfaces...
-    >;
+namespace WRL {
+template <typename... TInterfaces>
+using Base = Microsoft::WRL::RuntimeClass<
+    Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
+    TInterfaces...>;
}
using namespace Microsoft::WRL;
@@ -30,7 +28,7 @@ namespace Dml
class ExecutionProvider;
 class ExecutionProviderImpl : public WRL::Base<Dml::IExecutionProvider,
-                                               winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider>
+                                               Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider>
{
public:
explicit ExecutionProviderImpl::ExecutionProviderImpl(
@@ -158,7 +156,7 @@ namespace Dml
std::shared_ptr GetCpuInputAllocator();
std::shared_ptr GetCpuOutputAllocator();
-    std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap>
+    std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap>
GetInternalRegistrationInfoMap() const;
private:
@@ -175,7 +173,7 @@ namespace Dml
std::shared_ptr m_cpuInputAllocator;
std::shared_ptr m_cpuOutputAllocator;
std::shared_ptr m_kernelRegistry;
-    std::shared_ptr<winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap> m_internalRegInfoMap;
+    std::shared_ptr<Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap> m_internalRegInfoMap;
mutable uint64_t m_partitionKernelPrefixVal = 0;
bool m_closed = false;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp
index 0d5ab86b3582b..9391e191de86d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp
@@ -6,7 +6,7 @@
#include "MLOperatorAuthorImpl.h"
#include "FusedGraphKernel.h"
-using namespace winrt::Windows::AI::MachineLearning::implementation;
+using namespace Windows::AI::MachineLearning::Adapter;
namespace Dml
{
@@ -170,7 +170,7 @@ namespace Dml
}
else
{
- std::tie(unpackedTensor, tensorByteSize) = winrt::Windows::AI::MachineLearning::implementation::UnpackTensor(initializer);
+ std::tie(unpackedTensor, tensorByteSize) = UnpackTensor(initializer);
tensorPtr = unpackedTensor.get();
}
@@ -726,7 +726,7 @@ namespace Dml
ComPtr m_compiledExecutionPlanOperator;
std::vector m_inputsUsed;
const void* m_executionHandle = nullptr;
- ComPtr m_winmlProvider;
+ ComPtr m_winmlProvider;
ComPtr m_provider;
EdgeShapes m_outputShapes;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index 8b6b42c63a70b..622b7b96cb09c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -4,7 +4,7 @@
#include "precomp.h"
#include "GraphDescBuilder.h"
-using namespace winrt::Windows::AI::MachineLearning::implementation;
+using namespace Windows::AI::MachineLearning::Adapter;
namespace Dml::GraphDescBuilder
{
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
index 68fc7cc9f513d..02319daab0ab9 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
@@ -9,14 +9,14 @@ namespace Dml
{
struct GraphNodeProperties
{
- std::shared_ptr
+ std::shared_ptr
internalRegInfo;
    // These are currently passed from the partitioning step, since the only DML operators currently
    // supporting graph nodes don't customize the order of edges or shapes, other than coercing
    // dimension count. This will change as the set of operators supported as graph nodes increases.
- winrt::Windows::AI::MachineLearning::implementation::EdgeShapes inputShapes;
- winrt::Windows::AI::MachineLearning::implementation::EdgeShapes outputShapes;
+ Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes;
+ Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes;
};
namespace GraphDescBuilder
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp
index 1efbd3fd6b44b..e6ffb31d0084c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp
@@ -16,7 +16,7 @@
//#define PRINT_PARTITON_INFO
-using namespace winrt::Windows::AI::MachineLearning::implementation;
+using namespace Windows::AI::MachineLearning::Adapter;
namespace Dml
{
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h
index 48e787736b65a..2c9dc497e1364 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h
@@ -43,7 +43,7 @@ namespace Dml
std::vector>
BuildPartitions(
const onnxruntime::GraphViewer& graph,
- const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap& internalRegInfoMap,
+ const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap,
const std::vector& registries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::unordered_map& graphNodePropertyMap,
@@ -53,7 +53,7 @@ namespace Dml
std::vector>
PartitionGraph(
const onnxruntime::GraphViewer& graph,
- const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap& internalRegInfoMap,
+ const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap,
const std::vector& registries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
onnxruntime::KernelRegistry* registryForPartitionKernels,
@@ -64,7 +64,7 @@ namespace Dml
const onnxruntime::Node& node,
const onnxruntime::KernelRegistry& registry,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
- const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap& internalRegInfoMap,
+ const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap& internalRegInfoMap,
bool allow64BitInputThroughStrides,
_In_opt_ const std::unordered_map* nodeNameToPartitionMap // Only used when allow64BitInputThroughStrides is true
);
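The supportedDeviceDataTypeMask parameter above packs DirectML data-type support into a bitfield, one bit per DML_TENSOR_DATA_TYPE value. A hedged sketch of how such a mask could be assembled (the helper name is illustrative; the enum constants are from DirectML.h):

    #include <DirectML.h>
    #include <cstdint>
    #include <initializer_list>

    // Sets the bit whose position equals the enum's integer value.
    inline uint32_t MakeDataTypeMask(std::initializer_list<DML_TENSOR_DATA_TYPE> types)
    {
        uint32_t mask = 0;
        for (DML_TENSOR_DATA_TYPE t : types)
        {
            mask |= 1u << static_cast<uint32_t>(t);
        }
        return mask;
    }

    // e.g. a device supporting only 32- and 16-bit floats:
    // uint32_t mask = MakeDataTypeMask({DML_TENSOR_DATA_TYPE_FLOAT32,
    //                                   DML_TENSOR_DATA_TYPE_FLOAT16});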
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
index b959e0c930755..5df941e9ef887 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
@@ -13,7 +13,8 @@
using namespace Microsoft::WRL;
-namespace winrt::Windows::AI::MachineLearning::implementation {
+namespace Windows::AI::MachineLearning::Adapter
+{
size_t AttributeValue::ElementCount() const {
switch (type) {
@@ -91,8 +92,8 @@ bool IsAllocationInterface(const ::OrtMemoryInfo& info) {
// the ABI. The translation is determined by the provider and based on options with which the
// kernels are registered.
void TranslateAllocationDataToAbi(
- winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider* winmlProvider,
- bool isInternalOperator,
+ IWinmlExecutionProvider* winmlProvider,
+ bool isInternalOperator,
const ::OrtMemoryInfo& allocInfo,
IUnknown* allocation,
IUnknown** abiAllocation) {
@@ -1669,17 +1670,20 @@ EdgeShapes AbiOpKernel::GetInputShapes(onnxruntime::OpKernelContext* context) co
void AbiOpKernel::InferAndVerifyOutputSizes(
gsl::span requiredConstantCpuInputs,
- MLOperatorTensorGetter& constantInputGetter,
- const EdgeShapes* inputShapes,
- EdgeShapes& outputShapes) const {
- winrt::Windows::AI::MachineLearning::implementation::InferAndVerifyOutputSizes(
- Node(),
- m_defaultAttributes,
- m_shapeInferrer.Get(),
- requiredConstantCpuInputs,
- constantInputGetter,
- inputShapes,
- outputShapes);
+ MLOperatorTensorGetter& constantInputGetter,
+ const EdgeShapes* inputShapes,
+ EdgeShapes& outputShapes) const
+{
+    // Call the same-named non-member function (below); the explicit namespace
+    // qualification keeps this from recursing into this member function.
+ Windows::AI::MachineLearning::Adapter::InferAndVerifyOutputSizes(
+ Node(),
+ m_defaultAttributes,
+ m_shapeInferrer.Get(),
+ requiredConstantCpuInputs,
+ constantInputGetter,
+ inputShapes,
+ outputShapes
+ );
}
void InferAndVerifyOutputSizes(
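The qualification added above matters because the member function and the free function share a name; inside the member, an unqualified call would resolve to the member itself and recurse. A minimal standalone sketch of the pattern (illustrative names, not from this patch):

    namespace Adapter
    {
        void Process(int value);  // free function, defined below

        struct Kernel
        {
            void Process(int value) const
            {
                Adapter::Process(value);  // qualified: calls the free function
                // Process(value);        // unqualified: infinite recursion
            }
        };

        void Process(int /*value*/) { /* ... */ }
    }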
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h
index d6fabb2f7287b..0168da24ef4ca 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h
@@ -21,7 +21,7 @@ namespace WRL
>;
}
-namespace winrt::Windows::AI::MachineLearning::implementation
+namespace Windows::AI::MachineLearning::Adapter
{
using namespace Microsoft::WRL;
@@ -380,7 +380,7 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper<
bool m_allowOutputShapeQuery = false;
bool m_internalOperator = false;
-    ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
+    ComPtr<Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider> m_winmlProvider;
const onnxruntime::OpKernelInfo* m_impl = nullptr;
@@ -435,7 +435,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper<
// For shape info, in addition to the info
const EdgeShapes* m_inferredOutputShapes = nullptr;
-    ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
+    ComPtr<Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider> m_winmlProvider;
bool m_internalOperator = false;
// The execution object returned through the ABI, which may vary according to kernel
@@ -477,7 +477,7 @@ class OpKernelContextWrapper : public WRL::Base, publi
std::vector> m_outputTensors;
const onnxruntime::IExecutionProvider* m_provider = nullptr;
-    ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
+    ComPtr<Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider> m_winmlProvider;
bool m_internalOperator = false;
// The execution object returned to the kernel may vary according to kernel execution options
@@ -542,7 +542,7 @@ class AbiOpKernel : public onnxruntime::OpKernel
mutable std::mutex m_mutex;
mutable EdgeShapes m_inferredOutputShapes;
-    ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
+    ComPtr<Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider> m_winmlProvider;
bool m_internalOperator = false;
std::vector m_requiredConstantCpuInputs;
@@ -640,4 +640,4 @@ bool TryGetStaticOutputShapes(const onnxruntime::Node& node, EdgeShapes& outputS
bool ContainsEmptyDimensions(const EdgeShapes& shapes);
std::tuple, size_t> UnpackTensor(const onnx::TensorProto& initializer);
-} // namespace winrt::Windows::AI::MachineLearning::implementation
+} // namespace Windows::AI::MachineLearning::Adapter
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
index 7f27291ede562..b64ae7dc751ae 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
@@ -134,7 +134,7 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling
}
};
-void QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported)
+void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported)
{
*isSupported = false;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp
index e167a89f0606e..0e9d0feb5a815 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp
@@ -71,7 +71,7 @@ class DmlOperatorSliceTemplate : public DmlOperatorSlice
}
};
-void QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported)
+void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported)
{
*isSupported = (context->GetInputCount() <= 4);
}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h
index 186608d78b586..45923d528dc05 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h
@@ -8,8 +8,8 @@ interface IMLOperatorKernel;
class MLOperatorKernelCreationContext;
// Forward declares an external creation function.
-#define DML_OP_EXTERN_CREATION_FUNCTION(operatorName) extern void Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)
-#define DML_OP_EXTERN_QUERY_FUNCTION(operatorName) extern void Query##operatorName(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported);
+#define DML_OP_EXTERN_CREATION_FUNCTION(operatorName) extern void CALLBACK Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)
+#define DML_OP_EXTERN_QUERY_FUNCTION(operatorName) extern void CALLBACK Query##operatorName(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported);
// Declares a callback creation function of the given operator class.
// This does not register it, just declares it for usage by registration later.
@@ -20,7 +20,7 @@ class MLOperatorKernelCreationContext;
// commas in them break the macro, and so they are stuffed into the VA_ARGS.
//
#define DML_OP_DEFINE_CREATION_FUNCTION(operatorName, ...)\
-extern void Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)\
+extern void CALLBACK Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)\
{\
using T = __VA_ARGS__; \
THROW_IF_FAILED(MLOperatorKernel<T>::CreateInstance(*kernelInfo, /*out*/ opKernel));\
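The CALLBACK macro expands to __stdcall on Windows. On x86, a function pointer type that names a calling convention only accepts functions compiled with that same convention, so once the MLOperatorKernelCreateFn typedef gains CALLBACK (see MLOperatorAuthorHelper.h further down), every Create*/Query* function it points at must gain it too. A hedged standalone sketch of the constraint (illustrative names; on x64 the conventions unify, so the mismatch only bites 32-bit builds):

    #include <windows.h>

    using CreateFn = void (CALLBACK*)(int);

    void CALLBACK CreateFoo(int) {}  // __stdcall: matches CreateFn
    void CreateBar(int) {}           // default __cdecl on x86: would not match

    CreateFn fn = &CreateFoo;        // OK
    // CreateFn bad = &CreateBar;    // error on x86: incompatible calling convention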
diff --git a/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.cc b/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.cc
index 3d25aa1fe388e..9f63513e870e1 100644
--- a/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.cc
+++ b/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.cc
@@ -25,48 +25,17 @@
namespace GraphTransformerHelpers
{
- void RegisterGraphTransformers(onnxruntime::InferenceSession* lotusSession, bool registerLotusTransforms)
+ void RegisterGraphTransformers(onnxruntime::InferenceSession* lotusSession)
{
// Register Lotus graph transformers
+        // We were able to combine all of the WinML/DML/ORT transformer work except for two transformers,
+        // which are tracked by:
+        // Bug 22973884 : Fix issues with BatchNorm + Add and BatchNorm + Mul handling implicit inputs, and move from Winml to ORT
//
- // TODO: Work out issues controlling graph optimization passes through ORT's optimization level
- // and rule list. In the meantime (and before new transformers are tested in Winml), passes
- // are registered explicitly, and the optimization level is set to default above (no optimization).
- //
- // Issues:
- // Why is UnsqueezeElimination not registered by name in ORT?
- // Why are level 2 (default) transformers not run before partitioning, which the DML XP requires?
- // Why are level2 transformers only enabled on the CPU provider in GenerateTransformers?
- // Why does name filtering only apply to rule based graph transformers?
- // Why is Matmul+Add not used when contrib ops are disabled?
-
- if (registerLotusTransforms)
- {
- lotusSession->RegisterGraphTransformer(std::move(std::make_unique()), onnxruntime::TransformerLevel::Level1);
- }
-
        std::unique_ptr<onnxruntime::RuleBasedGraphTransformer> rule_transformer =
            std::make_unique<onnxruntime::RuleBasedGraphTransformer>("WinmlRuleTransformer");
-
- if (registerLotusTransforms)
- {
- rule_transformer->Register(std::make_unique());
- rule_transformer->Register(std::make_unique());
- rule_transformer->Register(std::make_unique());
- rule_transformer->Register(std::make_unique());
- rule_transformer->Register(std::make_unique());
- rule_transformer->Register(std::make_unique());
- rule_transformer->Register(std::make_unique());
- }
-
rule_transformer->Register(std::make_unique());
rule_transformer->Register(std::make_unique());
-
lotusSession->RegisterGraphTransformer(std::move(rule_transformer), onnxruntime::TransformerLevel::Level1);
-
- if (registerLotusTransforms)
- {
- lotusSession->RegisterGraphTransformer(std::move(std::make_unique()), onnxruntime::TransformerLevel::Level1);
- }
}
}
\ No newline at end of file
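The surviving registration path above follows the usual onnxruntime pattern: collect rewrite rules into a named rule-based transformer, then attach it to the session at a transformer level. A hedged sketch of that shape (MyBatchNormRule is a placeholder for the two remaining WinML rules, whose concrete names are elided in this hunk):

    // Assumes MyBatchNormRule is a RewriteRule subclass (illustrative only).
    std::unique_ptr<onnxruntime::RuleBasedGraphTransformer> rule_transformer =
        std::make_unique<onnxruntime::RuleBasedGraphTransformer>("WinmlRuleTransformer");
    rule_transformer->Register(std::make_unique<MyBatchNormRule>());
    lotusSession->RegisterGraphTransformer(std::move(rule_transformer),
                                           onnxruntime::TransformerLevel::Level1);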
diff --git a/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.h b/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.h
index 169597c0d1341..bd9b1148cf0b0 100644
--- a/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.h
+++ b/onnxruntime/core/providers/dml/GraphTransformers/GraphTransformerHelpers.h
@@ -5,5 +5,5 @@
namespace GraphTransformerHelpers
{
- void RegisterGraphTransformers(onnxruntime::InferenceSession* lotusSession, bool registerLotusTransforms);
+ void RegisterGraphTransformers(onnxruntime::InferenceSession* lotusSession);
}
\ No newline at end of file
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h
index b15955e0e533d..7fee23b8d0004 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h
@@ -734,7 +734,7 @@ class MLOperatorKernel : public Microsoft::WRL::RuntimeClass<
using MLOperatorTypeInferenceFunction = void (CALLBACK*)(IMLOperatorTypeInferenceContext*);
using MLOperatorShapeInferenceFunction = void (CALLBACK*)(IMLOperatorShapeInferenceContext*);
-using MLOperatorKernelCreateFn = void(*)(IMLOperatorKernelCreationContext*, IMLOperatorKernel**);
+using MLOperatorKernelCreateFn = void(CALLBACK*)(IMLOperatorKernelCreationContext*, IMLOperatorKernel**);
using MLOperatorSupportQueryFunction = void (CALLBACK*)(IMLOperatorSupportQueryContextPrivate*, bool*);
class MLOperatorShapeInferrer : public Microsoft::WRL::RuntimeClass<
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 03c003d2b6d0d..aa8486117fa97 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -6,15 +6,14 @@
#include "Common.h"
#include "Attributes.h"
#include "MLOperatorAuthorHelper.h"
+#include "core/common/common.h"
-namespace OperatorHelper
-{
+namespace OperatorHelper {
bool ContainsEmptyDimensions(gsl::span dimensions);
std::vector BroadcastTensorShape(
gsl::span inputShape0,
- gsl::span inputShape1
- );
+ gsl::span inputShape1);
// Find all the occurrences of a value, and return the array indices (in ascending order).
//
@@ -22,19 +21,16 @@ std::vector BroadcastTensorShape(
// value = 1
// output indices = {1,3,4}
#pragma optimize("", off)
-template <typename T>
-void FindValueIndices(gsl::span<const T> values, T value, /*out*/ std::vector<uint32_t>& indices)
-{
-    indices.clear();
-    for (size_t i = 0, valuesCount = values.size(); i < valuesCount; ++i)
-    {
-        // Work around compiler bug on x86 release by using data() rather than operator [] directly.
-        // cl.exe 19.20.27412.4 for x86
-        if (values.data()[i] == value)
-        {
-            indices.push_back(gsl::narrow_cast<uint32_t>(i));
-        }
+template <typename T>
+void FindValueIndices(gsl::span<const T> values, T value, /*out*/ std::vector<uint32_t>& indices) {
+  indices.clear();
+  for (size_t i = 0, valuesCount = values.size(); i < valuesCount; ++i) {
+    // Work around compiler bug on x86 release by using data() rather than operator [] directly.
+    // cl.exe 19.20.27412.4 for x86
+    if (values.data()[i] == value) {
+      indices.push_back(gsl::narrow_cast<uint32_t>(i));
    }
+  }
 }
#pragma optimize("", on)
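A hedged usage sketch for FindValueIndices, reproducing the worked example in the comment above (values {2,1,3,1,1,5}, searched value 1, expected indices {1,3,4}); it assumes this header and <vector> are included:

    void FindValueIndicesDemo()
    {
        std::vector<int64_t> values = {2, 1, 3, 1, 1, 5};
        std::vector<uint32_t> indices;
        OperatorHelper::FindValueIndices<int64_t>(gsl::make_span(values), int64_t(1), /*out*/ indices);
        // indices now holds {1, 3, 4}
    }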
@@ -51,248 +47,224 @@ void HandleNegativeAxes(gsl::span onnxAxes, uint32_t dimCount);
// e.g. input values = {2,1,3,1,1,5}
// ellidable input indices = {1,3,4}
// output values = {2,3,5}
-template <typename T>
-void RemoveValuesByIndex(gsl::span<const uint32_t> indices, bool keepOneValue, /*inout*/ std::vector<T>& values)
-{
- assert(std::is_sorted(indices.begin(), indices.end()));
-
- // Keep the last value at least, if all values would otherwise be removed.
- if (keepOneValue && !indices.empty() && indices.size() == values.size())
+template <typename T>
+void RemoveValuesByIndex(gsl::span<const uint32_t> indices, bool keepOneValue, /*inout*/ std::vector<T>& values) {
+ assert(std::is_sorted(indices.begin(), indices.end()));
+
+ // Keep the last value at least, if all values would otherwise be removed.
+ if (keepOneValue && !indices.empty() && indices.size() == values.size()) {
+ indices = indices.first(indices.size() - 1);
+ }
+
+ auto indicesIterator = indices.begin();
+ auto indicesEnd = indices.end();
+ size_t oldValuesCount = values.size();
+ size_t newValuesCount = 0;
+ size_t nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++);
+
+ // For every value, either skip the entry, or copy it to the output.
+ for (size_t i = 0; i < oldValuesCount; ++i) {
+ if (i == nextIndex) // Skip and remove entry.
{
- indices = indices.first(indices.size() - 1);
- }
-
- auto indicesIterator = indices.begin();
- auto indicesEnd = indices.end();
- size_t oldValuesCount = values.size();
- size_t newValuesCount = 0;
- size_t nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++);
-
- // For every value, either skip the entry, or copy it to the output.
- for (size_t i = 0; i < oldValuesCount; ++i)
+ nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++);
+ } else // Keep and copy entry.
{
- if (i == nextIndex) // Skip and remove entry.
- {
- nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++);
- }
- else // Keep and copy entry.
- {
- values[newValuesCount++] = values[i];
- }
-
+ values[newValuesCount++] = values[i];
}
- values.resize(newValuesCount);
+ }
+ values.resize(newValuesCount);
}
int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p);
-class EdgeShapes
-{
-public:
-    EdgeShapes() = default;
-    EdgeShapes(const std::vector<uint32_t>& dim) { m_shapes = dim; }
-    EdgeShapes(const std::initializer_list<uint32_t>& dim) { m_shapes.assign(dim.begin(), dim.end()); }
-    EdgeShapes(const gsl::span<const uint32_t> dim) { m_shapes.assign(dim.begin(), dim.end()); }
-
-    bool IsTensor() { return true; }
-    bool IsUnused() { return m_shapes.empty(); }
-
-    std::vector<uint32_t>& GetShape() { return m_shapes; }
-private:
-    std::vector<uint32_t> m_shapes;
-};
+class EdgeShapes {
+ public:
+  EdgeShapes() = default;
+  EdgeShapes(const std::vector<uint32_t>& dim) { m_shapes = dim; }
+  EdgeShapes(const std::initializer_list<uint32_t>& dim) { m_shapes.assign(dim.begin(), dim.end()); }
+  EdgeShapes(const gsl::span<const uint32_t> dim) { m_shapes.assign(dim.begin(), dim.end()); }
-struct KernelArgs
-{
- // Initialize arrays up to NcdhwSpatialDimensionCount to avoid vector allocations,
- // but it's important to use .spatialDimensionCount when accessing them because
- // values beyond that may be bogus.
- uint32_t strides[NcdhwSpatialDimensionCount];
- uint32_t dilations[NcdhwSpatialDimensionCount];
- uint32_t windowSize[NcdhwSpatialDimensionCount];
- uint32_t startPadding[NcdhwSpatialDimensionCount];
- uint32_t endPadding[NcdhwSpatialDimensionCount];
- uint32_t outputPadding[NcdhwSpatialDimensionCount];
-
- KernelArgs(uint32_t spatialDimensionCount) :
- autoPad(false),
- autoPadSameUpper(false),
- spatialDimensionCount(spatialDimensionCount)
- {
- ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
- }
+ bool IsTensor() { return true; }
+ bool IsUnused() { return m_shapes.empty(); }
-    void FillWithLeadingValues(gsl::span<const uint32_t> input, gsl::span<uint32_t> output, uint32_t fillCount, uint32_t value)
- {
- // e.g.
- // input = [5,6,7,8]
- // fillcount = 2
- // value = 1
- // output = [1,1,5,6,7,8]
-
- const size_t inputCount = input.size();
- const size_t outputCount = output.size();
- const size_t clampedFillCount = std::min(size_t(fillCount), outputCount);
- const size_t copyCount = std::min(outputCount - fillCount, inputCount);
-
- std::fill_n(output.data(), fillCount, value);
- std::copy_n(input.data(), copyCount, output.data() + fillCount);
- }
+  std::vector<uint32_t>& GetShape() { return m_shapes; }
- // Create a copy of an existing kernel args with a minimum dimension count,
- // filling the leading attribute values with 1's or 0's respectively.
- KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) :
- autoPad(kernelArgs.autoPad),
- autoPadSameUpper(kernelArgs.autoPadSameUpper),
- spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount))
- {
- ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
-
- uint32_t fillCount = (minimumDimensionCount > kernelArgs.spatialDimensionCount) ? minimumDimensionCount - kernelArgs.spatialDimensionCount : 0;
- FillWithLeadingValues(kernelArgs.strides, this->strides, fillCount, 1u);
- FillWithLeadingValues(kernelArgs.dilations, this->dilations, fillCount, 1u);
- FillWithLeadingValues(kernelArgs.windowSize, this->windowSize, fillCount, 1u);
- FillWithLeadingValues(kernelArgs.startPadding, this->startPadding, fillCount, 0u);
- FillWithLeadingValues(kernelArgs.endPadding, this->endPadding, fillCount, 0u);
- FillWithLeadingValues(kernelArgs.outputPadding, this->outputPadding, fillCount, 0u);
- }
+ private:
+  std::vector<uint32_t> m_shapes;
+};
- // This is true if padding must be automatically computed based on input sizes.
- // ResolveAutoPadding must happen during Compute rather than initialization.
- // This is temporary until kernel initialization routine once Lotus can provide
- // sizes at operator initialization.
- bool autoPad;
- bool autoPadSameUpper;
- uint32_t spatialDimensionCount;
+struct KernelArgs {
+ // Initialize arrays up to NcdhwSpatialDimensionCount to avoid vector allocations,
+ // but it's important to use .spatialDimensionCount when accessing them because
+ // values beyond that may be bogus.
+ uint32_t strides[NcdhwSpatialDimensionCount];
+ uint32_t dilations[NcdhwSpatialDimensionCount];
+ uint32_t windowSize[NcdhwSpatialDimensionCount];
+ uint32_t startPadding[NcdhwSpatialDimensionCount];
+ uint32_t endPadding[NcdhwSpatialDimensionCount];
+ uint32_t outputPadding[NcdhwSpatialDimensionCount];
+
+ KernelArgs(uint32_t spatialDimensionCount) : autoPad(false),
+ autoPadSameUpper(false),
+ spatialDimensionCount(spatialDimensionCount) {
+ ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
+ }
+
+  void FillWithLeadingValues(gsl::span<const uint32_t> input, gsl::span<uint32_t> output, uint32_t fillCount, uint32_t value) {
+ // e.g.
+ // input = [5,6,7,8]
+ // fillcount = 2
+ // value = 1
+ // output = [1,1,5,6,7,8]
+
+ const size_t inputCount = input.size();
+ const size_t outputCount = output.size();
+ const size_t clampedFillCount = std::min(size_t(fillCount), outputCount);
+ const size_t copyCount = std::min(outputCount - fillCount, inputCount);
+
+ std::fill_n(output.data(), fillCount, value);
+ std::copy_n(input.data(), copyCount, output.data() + fillCount);
+ }
+
+ // Create a copy of an existing kernel args with a minimum dimension count,
+ // filling the leading attribute values with 1's or 0's respectively.
+ KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) : autoPad(kernelArgs.autoPad),
+ autoPadSameUpper(kernelArgs.autoPadSameUpper),
+ spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount)) {
+ ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
+
+ uint32_t fillCount = (minimumDimensionCount > kernelArgs.spatialDimensionCount) ? minimumDimensionCount - kernelArgs.spatialDimensionCount : 0;
+ FillWithLeadingValues(kernelArgs.strides, this->strides, fillCount, 1u);
+ FillWithLeadingValues(kernelArgs.dilations, this->dilations, fillCount, 1u);
+ FillWithLeadingValues(kernelArgs.windowSize, this->windowSize, fillCount, 1u);
+ FillWithLeadingValues(kernelArgs.startPadding, this->startPadding, fillCount, 0u);
+ FillWithLeadingValues(kernelArgs.endPadding, this->endPadding, fillCount, 0u);
+ FillWithLeadingValues(kernelArgs.outputPadding, this->outputPadding, fillCount, 0u);
+ }
+
+  // This is true if padding must be automatically computed based on input sizes.
+  // ResolveAutoPadding must happen during Compute rather than initialization.
+  // This is temporary: it can move into the kernel initialization routine once
+  // Lotus can provide sizes at operator initialization.
+ bool autoPad;
+ bool autoPadSameUpper;
+ uint32_t spatialDimensionCount;
};
std::vector InitializeKernelOutputDimensions(
gsl::span inputDimensions,
- const KernelArgs& args
-);
+ const KernelArgs& args);
std::vector InitializeKernelOutputDimsTranspose(
gsl::span inputDimensions,
- const KernelArgs& args
-);
+ const KernelArgs& args);
KernelArgs InitializeGlobalKernel(gsl::span inputDimensions);
KernelArgs InitializeKernel(
const MLOperatorAttributes& kernelInfo,
uint32_t inputDimensionCount,
- gsl::span filterTensorShape
-);
+ gsl::span filterTensorShape);
void ResolveAutoPadding(
KernelArgs& args,
- gsl::span inputDimensions
-);
+ gsl::span inputDimensions);
+
+class GetOutputShapeAsInputShapeHelper {
+ public:
+ // Info_t is used to obtain attributes which will be used for calculating the output shape later.
+ // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
+  template <typename Info_t, typename Shape_t>
+  GetOutputShapeAsInputShapeHelper(const Info_t& info, const Shape_t& shape) {
+ ORT_UNUSED_PARAMETER(info);
+ ORT_UNUSED_PARAMETER(shape);
+ };
+
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+};
-class GetOutputShapeAsInputShapeHelper
-{
-public:
- // Info_t is used to obtain attributes which will be used for calculating the output shape later.
- // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
- template
- GetOutputShapeAsInputShapeHelper(const Info_t& info, const Shape_t& shape) {};
+class GetBroadcastedOutputShapeHelper {
+ public:
+ // Info_t is used to obtain attributes which will be used for calculating the output shape later.
+ // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
+  template <typename Info_t, typename Shape_t>
+  GetBroadcastedOutputShapeHelper(const Info_t& info, const Shape_t& shape){};
-    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
-class GetBroadcastedOutputShapeHelper
-{
-public:
- // Info_t is used to obtain attributes which will be used for calculating the output shape later.
- // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
- template
- GetBroadcastedOutputShapeHelper(const Info_t& info, const Shape_t& shape) {};
+class RandomUniformHelperBase {
+ public:
+ // Info_t is used to obtain attributes which will be used for calculating the output shape later.
+  template <typename Info_t>
+  RandomUniformHelperBase(const Info_t& info) {
+ m_high = info.GetOptionalAttribute(AttrName::High, 1.0f);
+ m_low = info.GetOptionalAttribute(AttrName::Low, 0.0f);
+
+ if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float)) {
+ m_seed = info.GetAttribute(AttrName::Seed);
+ } else {
+      m_seed = static_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
+ }
+ }
- std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+ protected:
+ float m_high;
+ float m_low;
+ float m_seed;
};
-class RandomUniformHelperBase
-{
-public:
- // Info_t is used to obtain attributes which will be used for calculating the output shape later.
- template
- RandomUniformHelperBase(const Info_t& info)
- {
- m_high = info.GetOptionalAttribute(AttrName::High, 1.0f);
- m_low = info.GetOptionalAttribute(AttrName::Low, 0.0f);
+class RandomUniformHelper : public RandomUniformHelperBase {
+ public:
+  template <typename Info_t, typename Shape_t>
+  RandomUniformHelper(const Info_t& info, const Shape_t& shape) : RandomUniformHelperBase(info) {
+ auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape);
+ ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing.");
+ m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end());
+ }
- if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float))
- {
- m_seed = info.GetAttribute(AttrName::Seed);
- }
- else
- {
- m_seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count());
- }
- }
-protected:
- float m_high;
- float m_low;
- float m_seed;
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+ private:
+ // Returns an empty vector if the optional attribute is missing.
+ std::vector m_tensorShape;
};
-class RandomUniformHelper : public RandomUniformHelperBase
-{
-public:
- template
- RandomUniformHelper(const Info_t& info, const Shape_t& shape) : RandomUniformHelperBase(info)
- {
- auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape);
- ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing.");
- m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end());
+class RandomNormalHelperBase {
+ public:
+ // Info_t is used to obtain attributes which will be used for calculating the output shape later.
+  template <typename Info_t>
+  RandomNormalHelperBase(const Info_t& info) {
+ m_mean = info.GetOptionalAttribute(AttrName::Mean, 0.0f);
+ m_scale = info.GetOptionalAttribute(AttrName::Scale, 1.0f);
+
+ if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float)) {
+ m_seed = info.GetAttribute(AttrName::Seed);
+ } else {
+      m_seed = static_cast<float>(std::chrono::high_resolution_clock::now().time_since_epoch().count());
}
+ }
- std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
-private:
- // Returns an empty vector if the optional attribute is missing.
- std::vector m_tensorShape;
+ protected:
+ float m_mean;
+ float m_scale;
+ float m_seed;
};
-class RandomNormalHelperBase
-{
-public:
- // Info_t is used to obtain attributes which will be used for calculating the output shape later.
- template
- RandomNormalHelperBase(const Info_t& info)
- {
- m_mean = info.GetOptionalAttribute(AttrName::Mean, 0.0f);
- m_scale = info.GetOptionalAttribute(AttrName::Scale, 1.0f);
-
- if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float))
- {
- m_seed = info.GetAttribute(AttrName::Seed);
- }
- else
- {
- m_seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count());
- }
- }
-protected:
- float m_mean;
- float m_scale;
- float m_seed;
-};
+class RandomNormalHelper : public RandomNormalHelperBase {
+ public:
+  template <typename Info_t, typename Shape_t>
+  RandomNormalHelper(const Info_t& info, const Shape_t& shape) : RandomNormalHelperBase(info) {
+ auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape);
+ ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing.");
+ m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end());
+ }
-class RandomNormalHelper : public RandomNormalHelperBase
-{
-public:
- template
- RandomNormalHelper(const Info_t& info, const Shape_t& shape) : RandomNormalHelperBase(info)
- {
- auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape);
- ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing.");
- m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end());
- }
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
- std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
-private:
- // Returns an empty vector if the optional attribute is missing.
- std::vector m_tensorShape;
+ private:
+ // Returns an empty vector if the optional attribute is missing.
+ std::vector m_tensorShape;
};
class ConvolutionHelperBase
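A hedged standalone sketch of what KernelArgs::FillWithLeadingValues computes, matching the comment's example (input [5,6,7,8], fillCount 2, value 1, output [1,1,5,6,7,8]):

    #include <algorithm>
    #include <cstdint>

    void FillWithLeadingValuesDemo()
    {
        uint32_t input[4] = {5, 6, 7, 8};
        uint32_t output[6] = {};
        std::fill_n(output, 2, 1u);         // leading fill values
        std::copy_n(input, 4, output + 2);  // then the original values
        // output == {1, 1, 5, 6, 7, 8}
    }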
@@ -320,18 +292,17 @@ class ConvolutionHelperBase
}
}
- void ResolvingPadding(gsl::span inputDimensions);
+ void ResolvingPadding(gsl::span inputDimensions);
-    const std::vector<EdgeShapes>& GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
- {
- return m_outputShapes;
- }
-
- template
- void InitializeKernelAndShapes(const Shape_t& shapeInfo)
- {
- const std::vector inputDimensions = shapeInfo.GetInputTensorShape(0);
- const std::vector filterDims = shapeInfo.GetInputTensorShape(1);
+  const std::vector<EdgeShapes>& GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const {
+ ORT_UNUSED_PARAMETER(shapeInfo);
+ return m_outputShapes;
+ }
+
+ template
+ void InitializeKernelAndShapes(const Shape_t& shapeInfo) {
+ const std::vector inputDimensions = shapeInfo.GetInputTensorShape(0);
+ const std::vector filterDims = shapeInfo.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(
inputDimensions.size() >= 3 && inputDimensions.size() <= 5,
@@ -358,10 +329,10 @@ class ConvolutionHelperBase
);
}
- const std::vector inputDimensions = shapeInfo.GetInputTensorShape(0);
- const std::vector filterDims = shapeInfo.GetInputTensorShape(1);
+ const std::vector inputDimensions = shapeInfo.GetInputTensorShape(0);
+ const std::vector filterDims = shapeInfo.GetInputTensorShape(1);
- ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount, "Input dimensions must be >= 3");
+ ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount, "Input dimensions must be >= 3");
if (hasDynamicPads)
{
@@ -396,49 +367,45 @@ class ConvolutionHelperBase
assert(m_outputShapes[0].GetShape().size() > C);
m_outputShapes[0].GetShape()[C] = filterDims[C] * m_groupCount;
- if (!outputShape.empty())
- {
- // Start padding, end padding, and output padding are all ignored if output shape is set.
- std::fill(m_kernel.outputPadding, m_kernel.outputPadding + m_kernel.spatialDimensionCount, 0);
-
- if (outputShape.size() > 2)
- {
- ML_CHECK_VALID_ARGUMENT(outputShape[outputShape.size() - 3] == gsl::narrow_cast(m_outputShapes[0].GetShape()[C]), "Output channel must be equivalent to filter channel.");
- }
-
- for (size_t i = 0; i < m_kernel.spatialDimensionCount; ++i)
- {
- size_t outputIndex = outputShape.size() - m_kernel.spatialDimensionCount + i;
- ML_CHECK_VALID_ARGUMENT(outputShape[outputIndex] >= gsl::narrow_cast(inputDimensions[H + i]), "Output dimension cannot be smaller than input dimension.");
- m_outputShapes[0].GetShape()[H + i] = outputShape[outputIndex];
- }
-
- const int dimOffset = gsl::narrow_cast(inputDimensions.size() - m_kernel.spatialDimensionCount);
-
- for (size_t i = 0; i < m_kernel.spatialDimensionCount; ++i)
- {
- int stride = m_kernel.strides[i];
- int windowSize = m_kernel.windowSize[i];
-
- // Compute padding such that in reverse order, the logical input (m_outputShapes below) is fully defined
- // for a convolution over the logical output region (inputDimensions below).
- //
- // The padding required is the first windowSize element (for the first logical output element),
- // plus (logicalOutput - 1) steps of stride (the distance between each windowed set of logical
- // input elements), minus the actual logical input size.
- int paddings = gsl::narrow_cast((inputDimensions[i + dimOffset] - 1) * stride + windowSize - m_outputShapes[0].GetShape()[i + dimOffset]);
- paddings = std::max(0, paddings);
-
- m_kernel.startPadding[i] = m_kernel.autoPadSameUpper ? (paddings + 1) / 2 : paddings / 2;
- m_kernel.endPadding[i] = paddings - m_kernel.startPadding[i];
- }
- }
+ if (!outputShape.empty()) {
+ // Start padding, end padding, and output padding are all ignored if output shape is set.
+ std::fill(m_kernel.outputPadding, m_kernel.outputPadding + m_kernel.spatialDimensionCount, 0);
+
+ if (outputShape.size() > 2) {
+ ML_CHECK_VALID_ARGUMENT(outputShape[outputShape.size() - 3] == gsl::narrow_cast(m_outputShapes[0].GetShape()[C]), "Output channel must be equivalent to filter channel.");
+ }
+
+ for (size_t i = 0; i < m_kernel.spatialDimensionCount; ++i) {
+ size_t outputIndex = outputShape.size() - m_kernel.spatialDimensionCount + i;
+ ML_CHECK_VALID_ARGUMENT(outputShape[outputIndex] >= gsl::narrow_cast(inputDimensions[H + i]), "Output dimension cannot be smaller than input dimension.");
+ m_outputShapes[0].GetShape()[H + i] = outputShape[outputIndex];
+ }
+
+ const int dimOffset = gsl::narrow_cast(inputDimensions.size() - m_kernel.spatialDimensionCount);
+
+ for (size_t i = 0; i < m_kernel.spatialDimensionCount; ++i) {
+ int stride = m_kernel.strides[i];
+ int windowSize = m_kernel.windowSize[i];
+
+ // Compute padding such that in reverse order, the logical input (m_outputShapes below) is fully defined
+ // for a convolution over the logical output region (inputDimensions below).
+ //
+ // The padding required is the first windowSize element (for the first logical output element),
+ // plus (logicalOutput - 1) steps of stride (the distance between each windowed set of logical
+ // input elements), minus the actual logical input size.
+ int paddings = gsl::narrow_cast((inputDimensions[i + dimOffset] - 1) * stride + windowSize - m_outputShapes[0].GetShape()[i + dimOffset]);
+ paddings = std::max(0, paddings);
+
+ m_kernel.startPadding[i] = m_kernel.autoPadSameUpper ? (paddings + 1) / 2 : paddings / 2;
+ m_kernel.endPadding[i] = paddings - m_kernel.startPadding[i];
+ }
}
+ }
-protected:
-    uint32_t m_groupCount;
-    KernelArgs m_kernel;
-    std::vector<EdgeShapes> m_outputShapes;
+ protected:
+  uint32_t m_groupCount;
+  KernelArgs m_kernel;
+  std::vector<EdgeShapes> m_outputShapes;
};
class ConvHelper : public ConvolutionHelperBase
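A worked numeric check of the padding formula in the comment above (illustrative values, SAME_UPPER auto-padding):

    #include <algorithm>

    void PaddingExample()
    {
        // logical inputDim = 5, stride = 2, windowSize = 3, target outputDim = 10
        int paddings = std::max(0, (5 - 1) * 2 + 3 - 10);  // = 1
        int startPadding = (paddings + 1) / 2;             // autoPadSameUpper -> 1
        int endPadding = paddings - startPadding;          // -> 0
        (void)startPadding; (void)endPadding;
    }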
@@ -470,6 +437,7 @@ class GemmHelper
    template <typename Info_t, typename Shape_t>
GemmHelper(const Info_t& info, const Shape_t& shape)
{
+ ORT_UNUSED_PARAMETER(shape);
m_transA = info.GetOptionalAttribute(AttrName::TransA, 0);
m_transB = info.GetOptionalAttribute(AttrName::TransB, 0);
m_broadcast = info.GetOptionalAttribute(AttrName::Broadcast, 0);
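ORT_UNUSED_PARAMETER, added above, silences unused-parameter warnings on templated helpers whose signatures are fixed by the registration machinery; the include of core/common/common.h added at the top of this header is what brings it in. A common definition for such a macro (an assumption here, not quoted from this patch):

    // Evaluates the argument and discards the result, marking it "used".
    #define ORT_UNUSED_PARAMETER(x) (void)(x)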
@@ -477,61 +445,57 @@ class GemmHelper
m_beta = info.GetOptionalAttribute(AttrName::Beta, 0.0f);
}
-    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
- enum InputTensors { IN_A, IN_B, IN_C };
+ enum InputTensors { IN_A,
+ IN_B,
+ IN_C };
-protected:
- bool m_transA = false;
- bool m_transB = false;
- bool m_broadcast = false;
- float m_alpha = 0.0f;
- float m_beta = 0.0f;
+ protected:
+ bool m_transA = false;
+ bool m_transB = false;
+ bool m_broadcast = false;
+ float m_alpha = 0.0f;
+ float m_beta = 0.0f;
};
-class TransposeHelper
-{
-public:
- void Initialize(
- const MLOperatorAttributes& operatorAttributes,
- gsl::span inputDimensions
- );
+class TransposeHelper {
+ public:
+ void Initialize(
+ const MLOperatorAttributes& operatorAttributes,
+ gsl::span inputDimensions);
- // Info_t is used to obtain attributes which will be used for calculating the output shape later.
- // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
- template
- TransposeHelper(const Info_t& info, const Shape_t& shape)
- {
- Initialize(info, shape.GetInputTensorShape(0));
- }
+ // Info_t is used to obtain attributes which will be used for calculating the output shape later.
+ // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
+  template <typename Info_t, typename Shape_t>
+  TransposeHelper(const Info_t& info, const Shape_t& shape) {
+ Initialize(info, shape.GetInputTensorShape(0));
+ }
-    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
-protected:
- std::vector m_permutations;
+ protected:
+ std::vector m_permutations;
};
-class SplitHelper
-{
-public:
- void Initialize(
- const MLOperatorAttributes& operatorAttributes,
- gsl::span inputDimensions
- );
+class SplitHelper {
+ public:
+ void Initialize(
+ const MLOperatorAttributes& operatorAttributes,
+ gsl::span inputDimensions);
- // Info_t is used to obtain attributes which will be used for calculating the output shape later.
- // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
- template
- SplitHelper(const Info_t& info, const Shape_t& shape)
- {
- Initialize(info, shape.GetInputTensorShape(0));
- }
+ // Info_t is used to obtain attributes which will be used for calculating the output shape later.
+ // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
+  template <typename Info_t, typename Shape_t>
+  SplitHelper(const Info_t& info, const Shape_t& shape) {
+ Initialize(info, shape.GetInputTensorShape(0));
+ }
-    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
-protected:
- int m_axis = 0;
- std::vector m_split;
+ protected:
+ int m_axis = 0;
+ std::vector m_split;
};
class SliceHelperBase
@@ -565,7 +529,7 @@ class SliceHelperBase
ends.push_back(gsl::narrow_cast(endsData[i]));
}
uint32_t inputCount = operatorInfo.GetInputCount();
- if (operatorInfo.GetInputCount() > 3)
+ if (inputCount > 3)
{
MLOperatorTensor axesTensor = operatorInfo.GetConstantInputTensor(3);
const std::vector& axesTensorDimensions = axesTensor.GetShape();
@@ -577,7 +541,7 @@ class SliceHelperBase
}
}
- if (operatorInfo.GetInputCount() > 4)
+ if (inputCount > 4)
{
MLOperatorTensor stepsTensor = operatorInfo.GetConstantInputTensor(4);
const std::vector& stepsTensorDimensions = stepsTensor.GetShape();
@@ -620,6 +584,9 @@ class SliceHelperBase
ReadIndexTensors(operatorInfo, starts, ends, axes, steps);
}
}
+
+        const uint32_t dimCount = gsl::narrow_cast<uint32_t>(inputDimensions.size());
+ HandleNegativeAxes(/*inout*/ axes, dimCount);
ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size.");
ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty.");
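HandleNegativeAxes, called in the lines added above, normalizes ONNX's negative-axis convention: an axis in [-rank, rank-1] counts from the end when negative. A hedged sketch of the transformation, with a concrete example (the function name here is illustrative, not the declaration from this header):

    #include <cstdint>
    #include <gsl/gsl>

    void HandleNegativeAxesSketch(gsl::span<int32_t> axes, uint32_t dimCount)
    {
        for (int32_t& axis : axes)
        {
            if (axis < 0)
            {
                axis += static_cast<int32_t>(dimCount);  // e.g. rank 4: -1 -> 3
            }
        }
    }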
@@ -668,13 +635,13 @@ class SliceHelperBase
Initialize(info, shape.GetInputTensorShape(0), opsetVersion);
}
-    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+  std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
-protected:
- std::vector m_outputDimensions;
- std::vector m_offsets;
- std::vector m_sizes;
- std::vector m_strides;
+ protected:
+ std::vector m_outputDimensions;
+ std::vector m_offsets;
+ std::vector m_sizes;
+ std::vector