From 34b5e65314bd8350d2fe3b8a3f48f69c143c99a7 Mon Sep 17 00:00:00 2001
From: smk2007
Date: Tue, 28 Jan 2020 15:43:28 -0800
Subject: [PATCH] User/sheilk/winml adapter c api (#2891)

* Create winml adapter c api
* fix build
* make it build
* move adapter into onnxruntime core/session
* entry point not exported
* minor changes
* make model metadata work
* make tests pass
* implement all the model reflection apis on the adapter c abi
* update the new ort interface to create a lotus environment with a logging sink
* start adding ort env
* move all winml code into adapter folder/lib to isolate it
* ensure a single logging manager at a time
* start refactoring session
* refactor session creation interface
* add cpu and dml session option methods to adapter
* finish session init
* stub out interfaces in ort lib to perform similar mechanics of IInferenceSession
* enable profiling, and enable schema override
* update session register graph transformers
* turn back on custom registry for custom ops
* Add sync api
* add last c api stubs
* should build... but all feature values are broken since this is in flight toward moving all implementation details into ivalue
* remove ep adapter header
* Implement DML execution provider functions from adapter (#2846)
* Implement DML execution provider functions from adapter
* Use functions in OnnxruntimeEngine.cpp
* make map/sequence type_infos freeable, and start implementing ivalue
* make it build again
* implement value methods
* implement remaining methods
* remove com adapter abi
* check dml session
* cache the allocator on ivalue
* check if resource is cpu/gpu when accessing its mutable data
* update tensor
* mismatched parentheses
* fix tensor base and binding obj
* it evaluates tensors! sometimes...
* minor fixes
* enable gpu evals
* wrap all existing winml adapter apis with API_IMPL try/catch (#2854)
* update winml... tensor strings are broken, need to template tensorbase to do different things for strings
* make tensor strings work with 2 copies in/2 copies out
* Fix tensor string and allocator bug
* make maps work again... needs some fixes still
* Make it build!
* enable map inputs
* map outputs
* unbound outputs for sequences and maps
* User/xianz/merge windowsai (#2883)
* Packaging pipeline changes for VS 2019 (#2711)
* Tiny fix to codegen
* Simplify cache implementation and avoid static variables that may carry over between models
* Extend DML kernels (#2641)
* Additional DML operators
* Check unsupported attributes and inputs
* Address PR comments
* Add kernel capability function used for partitioning, and re-enable stride-based int64 support based on value range
* Fix test failures
* Build fix
* PR comments
* Update Nuphar tutorial notebook (#2721) 1. Reflect int8 GEMV improvements for multi-threading from #2696 2. Add notes on multi-threading control using OpenMP 3. Add samples of running multi-isa AOT, and show int8 GEMM differences between AVX and AVX2 4. Add rnn_benchmark example to resolve #1993
* Add schema for new Qops (#2611)
* Add schema for new Qops
* adding shape inference + qlinearaveragepool
* plus review comments
* plus review comments
* updates per review comments
* plus review comments
* [server] Add support for model_name and model_version as cli parameter (#2708)
* remove 64bit warning message from python validation. (#2727)
* MLAS: ARM64 build fix (#2734) fix bad usage of vreinterpret to cast vector element types
* Fix broken python docs links (#2740)
* Fix build on Mac OS (#2731) mac os ld doesn't support --whole-archive, the correct option is -all_load
* fix ngraph wheel (#2737)
* fix ngraph wheel 1.1.0 onnxruntime_ngraph wheel doesn't work
* remove libdnnl.so in nGraph Libs
* make it easy to compare
* Split onnxruntime server into a separate folder (#2744)
* Fix build for Python 3.8 (#2747)
* Fix build for Python 3.8
* Update protobuf to 3.11.2 (#1928)
* Change default optimization level to All (from Basic) (#2745)
* change default optimization level to All (from Basic)
* fix test
* fix c# test
* Update numpy to 1.18 (#2758)
* Update numpy to 1.18
* Pipeline changes for python 3.8 (#2753) 1. Pipeline changes for python 3.8 2. Fix a regression in setup.py which was just introduced in the previous commit. Please note, we still haven't made python 3.8 + Windows + CUDA work.
* Add basic stacktrace output for posix debug builds. (#2749)
* [NupharEP] fix a race condition when multiple sessions run different models concurrently (#2772)
* Revert "Change default optimization level to All (from Basic) (#2745)" This reverts commit 56bb503c2f26474b6613bcb2a198691a11dcef38.
* Fix typo in error message (#2736)
* Rename MKL-DNN to DNNL to fix broken link (#2730)
* Fix nightly build version number issue
* Pass BUILD_BUILDNUMBER to linux docker
* Disable featurizers in python packages
* Import more featurizers (#2781) Make kernels non-template. Add input constraint for learnt data. Add min_max_scalar_transformer, robust_scalar_transformer, inputation_marker_transfomer, label_encoder_transformer, missing_dummies_transformer along with tests. Advance Featurizers library commit.
* Implement a more stable softmax (#2715)
* Implement a more stable SoftMax. e^x is represented as infinity if x is large enough, like 100.f, and infinity divided by infinity is NaN; thus softmax yields NaN if one or more items are large enough. The math transform below is leveraged to get a stable softmax: e^xi/(e^x1 + ... + e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)). For convenience, max is forced to 0.f if all xi are negative (a short illustrative sketch follows below).
* Contributing: Fix a typo (#2784)
* ACL EP GEMM improvements (#2780) When it is possible we use a fully connected layer instead of the gemm implementation. This will let the library use the best implementation based on the input data.
* ACL EP convolution improvements (#2774) Added the optimized implementation for depthwise convolution for both ACL v19.02 and ACL 19.05. Also the pointwise convolution seems to be more optimal in the CPU implementation so we opted for that instead.
* Add script for release Nuget validation (#2719)
* Initial commit
* Nits
* Disable a test temporarily
* Change working directory
* Test
* Add download python step
* Test update
* More changes
* Fix space issue
* Fix
* Verify nuget signing
* Fix
* Spaces
* PR feedback
* Nit
* Fix
* Fix
* Remove temporary changes
* add uint8 support to where op (#2792)
* Improve bert optimization script: (#2712) (1) Move input int64=>int32 conversion to embed layer fusion. (2) Output epsilon attribute for LayerNormalization fusion.
* add session creation time cost. (#2798)
* ML.NET team needs featurizers within a package (#2789) Add auto ml featurizers to Windows, MacOS as well as to GPU packaging-pipelines.
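The stable-softmax commit above (#2715) relies on softmax's shift invariance: subtracting any constant m from every input leaves the result unchanged, and choosing m = max(x) keeps every exponent non-positive so exp can never overflow to infinity. A minimal illustrative sketch of that transform (not the kernel code from the commit; assumes a non-empty input):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Numerically stable softmax: shift by the max before exponentiating so
    // std::exp never overflows to infinity (inf/inf would produce NaN).
    std::vector<float> StableSoftmax(const std::vector<float>& x) {
      // Force the shift to 0.f when every input is negative, as the commit notes.
      const float m = std::max(0.0f, *std::max_element(x.begin(), x.end()));
      std::vector<float> out(x.size());
      float sum = 0.0f;
      for (size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(x[i] - m);
        sum += out[i];
      }
      for (float& v : out) v /= sum;  // normalize: each term is in [0, 1]
      return out;
    }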
* Initialize max of softmax with lowest of float (#2786)
* MLAS: update SGEMM threading parameters (#2808)
* add interface to copy batch tensors. (#2807)
* add interface to copy batch tensors.
* onnxruntime
* speed up Windows TRT CI (#2811)
* don't run cuda tests if building with tensorrt
* remove unnecessary build options for win trt ci
* refactor win gpu tensorrt ci yml
* --numpy_version=1.17
* update
* update
* azcopy and cuda path
* Update test data (#2356)
* Add timeseries imputer transformer featurizer kernel (#2813) Make kernels non-template. Add input constraint for learnt data. Fixup tests. Add two more featurizers along with tests. Tests fail: min_max_scalar_transformer, robust_scalar_transformer. Fix tests' serialized stream by prepending version bytes. Add inputation_marker_transfomer and the test. Fix up float/double type designations. Added label_encoder_transformer along with a test. string_throw case is broken at the moment. Fix labelencodertransfomer_test.cc string_throw case. Rename maxabsscalertransformer_test.cc. Add MissingDummiesTransformer along with the test. Update manifest. Add TimeSeriesImputerTransformer definition, implementation and tests
* Fix memory leak in TRT (#2815)
* fix memory leak issue
* revert EP_FAIL on enqueueV2
* Add missing comma in manifest
* Run static code analyzer on most of our code (#2817)
* Scenario Test: Build Google Test and Taef Test based on preprocessor definition (#2809)
* Add winml macro wrappers on top of google test macros
* change test methods to disabled
* Add custom winml macros for both taef and google tests
* PR comments
* update quantization doc (#2783)
* update documentation for quantization script
* plus some spell corrections
* Filter CPU case for IsFloat16Supported (#2802)
* update default optimization level + fix gemm_activation fusion (#2791)
* update default optimization level + fix gemm_activation fusion
* fix typo
* add unit test and incorporate review comments
* fix test comment
* Fix dnnl wheel package name (#2823)
* Append '-dnnl' to whl package name when --use_dnnl
* Update build.py
* Update Ubuntu & TensorRT version in README (#2820) Dockerfile.tensorrt is using nvcr.io/nvidia/tensorrt:19.09-py3 as base image; update Ubuntu and TensorRT version according to https://docs.nvidia.com/deeplearning/sdk/tensorrt-container-release-notes/rel_19-09.html#rel_19-09
* Merge fixes
* Add OneHotEncoder and HashOneHotEncoder kernels. (#2830) Add defs and implementation for OneHotEncoders, adjust date_time_transformer kernel and test. Add OneHotEncoder kernel test. Add HashOneHotVectorizerTransformer unit test. This does not link due to multiple definitions of functions that are included into a header from a CPP file.
* Upgrade gtest to the latest version (#2827) WinML would like to update the googletest submodule. They want some newer features (namely GTEST_SKIP to skip tests programmatically and be able to skip entire fixtures easily) and would need to update the submodule version. However, the new version of the code hits a bug in gcc. The bug is already fixed in the latest gcc, but we're using gcc 4.8.x, which won't get a patch for it, so we compromise and change our code a little to make it work. The gcc bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51213
* Add support for int64_t for topk CPU. Fixes github issue #2806. (#2833)
* Ignore allocator type in ExecutionProviders allocator map. Make default initialization of OrtMemoryInfo more clearly invalid. (#2768)
* Remove allocator type from the key comparison in ExecutionProviders. Remove usage of DummyArena as it's no longer necessary.
* Fix x86 tests where arena allocator is disabled. Make initialization of OrtMemoryInfo clearer by adding Invalid enum value.
* Make OrtValueNameIdxMap::MaxIdx more intuitive.
* Convert ExternalProject Featurizers into git submodule (#2834) Add git submodule for Featurizer library. Update cmake to build for git submodule.
* add domain check for nodes + update documentation (#2831)
* Fix cgmanifest.json generating script (#2770)
* Fix protobuf submodule name
* Workaround pygit2 bug
* User/orilevari/32bit comparison warning (#2800)
* use the correct type for the for loop
* explicitly specify void for the parameters of OrtGetApiBase: because the function is defined in C, a bare () parameter list is interpreted as taking an unknown number of parameters, which caused compiler warning C4276.
* CMake cross-generator fixes (#2790)
* Fix compilation w/ non-VS CMake generators
* Fix custom WINMD target in Ninja
* Remove usage of msbuild .targets file
* Fix linking using DML in Ninja
* Automate SDK kit version choice
* Cleanup DML package install
* Fix SDK version detection
* Fix comment
* Revert unittest linkage changes
* Fix latest SDK detection
* Don't link to non-uapcore libraries
* Remove MessageBoxA reference and unused link libs
* Fix Linux CUDA nuget packaging pipeline break
* Refactor WinMLAPI Tests to build both google and taef test based on preprocessor definition (#2829)
* Add winml macro wrappers on top of google test macros
* change test methods to disabled
* Add custom winml macros for both taef and google tests
* PR comments
* Refactor winml api tests
* Move additional gtest specific macro definition into googleTestMacros.h
* Fix test build break: winml_lib_api needs to be statically linked to the tests because winmlp::learningmodeldevice::iscpu() is used in devicehelpers.cpp (#2837)
* Enforce WINML_TEST_CLASS_BEGIN_* matches w/ a WINML_TEST_CLASS_END (#2841)
* update optimization doc for BERT related fusions (#2819)
* Add bert related transformers to doc
* Add execution provider and comment for bert optimizations
* Add comment about accuracy impact of approximation
* Fix warnings that cause the build to fail
* MLAS: enable threading for quantized GEMMs (#2844)
* Fix test warnings and delayload linking (#2843)
* OrtMemoryInfo struct changed
* mark the camera scenario test as edgecore because it uses d3d11 (#2852)
* User/orilevari/pipeline fi breaks (#2853)
* remove conflicting artifact names. Decided to stop using drop-nuget-cuda since this may have implications on other dependent pipelines.
* change job name in gpu.yml back to Windows_CI_GPU_CUDA_Dev
* Remove internal libs from tests (#2864)
* Support custom DML in onnxruntime_providers.cmake (#2867)
* remove old winmladapter cpp

Co-authored-by: Changming Sun
Co-authored-by: KeDengMS
Co-authored-by: Jeff <38966965+jeffbloo@users.noreply.github.com>
Co-authored-by: Ashwini Khade
Co-authored-by: Andrey
Co-authored-by: George Wu
Co-authored-by: Tracy Sharpe <42477615+tracysh@users.noreply.github.com>
Co-authored-by: Faith Xu
Co-authored-by: zhanyi-ms
Co-authored-by: Changyoung Koh
Co-authored-by: Scott McKay
Co-authored-by: Takeshi Watanabe
Co-authored-by: Dmitri Smirnov
Co-authored-by: Yufeng Li
Co-authored-by: Maher Jendoubi
Co-authored-by: Andrews548 <32704142+Andrews548@users.noreply.github.com>
Co-authored-by: Hariharan Seshadri
Co-authored-by: Nathan <7902510+ybrnathan@users.noreply.github.com>
Co-authored-by: Tianlei Wu
Co-authored-by: Ke Zhang
Co-authored-by: stevenlix <38092805+stevenlix@users.noreply.github.com>
Co-authored-by: Ryan Lai
Co-authored-by: Ori Levari
Co-authored-by: Yingge WAN
Co-authored-by: Qing
Co-authored-by: Pranav Sharma
Co-authored-by: Tiago Koji Castro Shibata

* move sequence implementation into ort lib... still commented out... need to turn back on...
* begin sequence implementation
* make maps and sequences work
* fix broken tests
* remove dead code
* misc cleanup
* CR feedback
* User/xianz/winml adapter c api (#2869)
* wrap all existing winml adapter apis with API_IMPL try/catch
* Return HR or Throw for WinML adapter APIs if failed
* undo macro wrapper in two places
* Wrap error macros around ort apis, too.
* address CR feedback #2
* add more api throw/return macros
* Revert changes no longer needed
* revert changes to cxx api
* format winml lib.ort and winml adapter
* remove static pheonix singleton

Co-authored-by: Ryan Lai
Co-authored-by: Xiang Zhang
Co-authored-by: Changming Sun
Co-authored-by: KeDengMS
Co-authored-by: Jeff <38966965+jeffbloo@users.noreply.github.com>
Co-authored-by: Ashwini Khade
Co-authored-by: Andrey
Co-authored-by: George Wu
Co-authored-by: Tracy Sharpe <42477615+tracysh@users.noreply.github.com>
Co-authored-by: Faith Xu
Co-authored-by: zhanyi-ms
Co-authored-by: Changyoung Koh
Co-authored-by: Scott McKay
Co-authored-by: Takeshi Watanabe
Co-authored-by: Dmitri Smirnov
Co-authored-by: Yufeng Li
Co-authored-by: Maher Jendoubi
Co-authored-by: Andrews548 <32704142+Andrews548@users.noreply.github.com>
Co-authored-by: Hariharan Seshadri
Co-authored-by: Nathan <7902510+ybrnathan@users.noreply.github.com>
Co-authored-by: Tianlei Wu
Co-authored-by: Ke Zhang
Co-authored-by: stevenlix <38092805+stevenlix@users.noreply.github.com>
Co-authored-by: Ori Levari
Co-authored-by: Yingge WAN
Co-authored-by: Qing
Co-authored-by: Pranav Sharma
Co-authored-by: Tiago Koji Castro Shibata
---
 cmake/winml.cmake                             |  123 +-
 .../providers/winml/winml_provider_factory.h  |   13 +-
 .../core/session/onnxruntime_c_api.h          |    2 +
 .../core/session/onnxruntime_cxx_api.h        |   25 +-
 .../core/session/onnxruntime_cxx_inline.h     |   56 +-
 onnxruntime/core/framework/allocatormgr.cc    |    5 -
 onnxruntime/core/framework/allocatormgr.h     |   21 -
 .../framework/onnxruntime_map_type_info.cc    |   84 ++
 .../framework/onnxruntime_map_type_info.h     |   27 +
 .../onnxruntime_sequence_type_info.cc         |   49 +
 .../onnxruntime_sequence_type_info.h          |   25 +
 .../core/framework/onnxruntime_typeinfo.cc    |  128 +-
 .../core/framework/onnxruntime_typeinfo.h     |   14 +-
 .../core/framework/tensor_type_and_shape.cc   |    5 +
 .../core/framework/tensor_type_and_shape.h    |    2 +
.../providers/dml/dml_provider_factory.cc | 21 +- onnxruntime/core/session/onnxruntime_c_api.cc | 114 +- onnxruntime/core/session/onnxruntime_env.cc | 90 ++ onnxruntime/core/session/onnxruntime_env.h | 65 + onnxruntime/core/session/ort_apis.h | 14 + .../azure-pipelines/win-ci-pipeline.yml | 12 +- .../azure-pipelines/win-gpu-ci-pipeline.yml | 3 +- winml/adapter/CpuOrtSessionBuilder.cpp | 102 -- winml/adapter/CpuOrtSessionBuilder.h | 30 - winml/adapter/CustomRegistryHelper.h | 30 - winml/adapter/DmlOrtSessionBuilder.cpp | 167 --- winml/adapter/DmlOrtSessionBuilder.h | 34 - winml/adapter/FeatureDescriptorFactory.h | 21 - winml/adapter/LotusEnvironment.cpp | 117 -- winml/adapter/LotusEnvironment.h | 90 -- winml/adapter/WinMLAdapter.cpp | 759 ---------- winml/adapter/WinMLAdapter.h | 202 --- winml/adapter/WinMLAdapterErrors.h | 41 - winml/adapter/ZeroCopyInputStreamWrapper.cpp | 77 - winml/adapter/ZeroCopyInputStreamWrapper.h | 43 - ...yImpl.cpp => abi_custom_registry_impl.cpp} | 2 +- ...istryImpl.h => abi_custom_registry_impl.h} | 4 +- winml/adapter/winml_adapter_apis.h | 85 ++ winml/adapter/winml_adapter_c_api.cpp | 105 ++ winml/adapter/winml_adapter_c_api.h | 469 ++++++ winml/adapter/winml_adapter_dml.cpp | 159 +++ winml/adapter/winml_adapter_environment.cpp | 84 ++ .../winml_adapter_execution_provider.cpp | 90 ++ winml/adapter/winml_adapter_model.cpp | 429 ++++++ winml/adapter/winml_adapter_model.h | 29 + winml/adapter/winml_adapter_session.cpp | 240 ++++ winml/api/Windows.AI.MachineLearning.idl | 62 +- winml/dll/module.cpp | 22 +- .../Api.Ort/OnnxruntimeCpuSessionBuilder.cpp | 89 ++ .../Api.Ort/OnnxruntimeCpuSessionBuilder.h | 32 + .../OnnxruntimeDescriptorConverter.cpp} | 407 +++--- .../Api.Ort/OnnxruntimeDescriptorConverter.h | 33 + .../Api.Ort/OnnxruntimeDmlSessionBuilder.cpp | 105 ++ .../Api.Ort/OnnxruntimeDmlSessionBuilder.h | 34 + winml/lib/Api.Ort/OnnxruntimeEngine.cpp | 1265 +++++++++++++++++ winml/lib/Api.Ort/OnnxruntimeEngine.h | 143 ++ .../lib/Api.Ort/OnnxruntimeEngineBuilder.cpp | 72 + winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h | 33 + winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp | 148 ++ winml/lib/Api.Ort/OnnxruntimeEnvironment.h | 26 + winml/lib/Api.Ort/OnnxruntimeErrors.h | 65 + winml/lib/Api.Ort/OnnxruntimeModel.cpp | 220 +++ winml/lib/Api.Ort/OnnxruntimeModel.h | 80 ++ winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h | 23 + winml/lib/Api.Ort/inc/OnnxruntimeProvider.h | 8 + winml/lib/Api.Ort/pch.h | 20 + winml/lib/Api/FeatureValues.h | 4 +- winml/lib/Api/ImageFeatureDescriptor.cpp | 24 +- winml/lib/Api/ImageFeatureDescriptor.h | 21 +- winml/lib/Api/ImageFeatureValue.cpp | 177 +-- winml/lib/Api/ImageFeatureValue.h | 8 +- winml/lib/Api/LearningModel.cpp | 122 +- winml/lib/Api/LearningModel.h | 24 +- winml/lib/Api/LearningModelBinding.cpp | 319 ++--- winml/lib/Api/LearningModelBinding.h | 39 +- winml/lib/Api/LearningModelSession.cpp | 243 ++-- winml/lib/Api/LearningModelSession.h | 26 +- winml/lib/Api/MapFeatureDescriptor.cpp | 13 - winml/lib/Api/MapFeatureDescriptor.h | 17 +- winml/lib/Api/SequenceFeatureDescriptor.cpp | 10 - winml/lib/Api/SequenceFeatureDescriptor.h | 13 +- winml/lib/Api/TensorFeatureDescriptor.cpp | 21 +- winml/lib/Api/TensorFeatureDescriptor.h | 26 +- winml/lib/Api/impl/MapBase.h | 133 +- winml/lib/Api/impl/SequenceBase.h | 208 +-- winml/lib/Api/impl/Tensor.h | 39 +- winml/lib/Api/impl/TensorBase.h | 212 +-- winml/lib/Api/impl/TensorBuffer.h | 24 +- winml/lib/Api/impl/TensorKindFrom.h | 16 +- .../Api/impl/TensorMemoryBufferReference.h | 62 +- 
.../lib/Api/inc/ILotusValueProviderPrivate.h | 6 +- winml/lib/Common/inc/PheonixSingleton.h | 6 +- winml/lib/Common/inc/iengine.h | 181 +++ winml/lib/Common/inc/onnx.h | 3 - .../test/api/LearningModelBindingAPITest.cpp | 2 + 95 files changed, 5707 insertions(+), 3486 deletions(-) create mode 100644 onnxruntime/core/framework/onnxruntime_map_type_info.cc create mode 100644 onnxruntime/core/framework/onnxruntime_map_type_info.h create mode 100644 onnxruntime/core/framework/onnxruntime_sequence_type_info.cc create mode 100644 onnxruntime/core/framework/onnxruntime_sequence_type_info.h create mode 100644 onnxruntime/core/session/onnxruntime_env.cc create mode 100644 onnxruntime/core/session/onnxruntime_env.h delete mode 100644 winml/adapter/CpuOrtSessionBuilder.cpp delete mode 100644 winml/adapter/CpuOrtSessionBuilder.h delete mode 100644 winml/adapter/CustomRegistryHelper.h delete mode 100644 winml/adapter/DmlOrtSessionBuilder.cpp delete mode 100644 winml/adapter/DmlOrtSessionBuilder.h delete mode 100644 winml/adapter/FeatureDescriptorFactory.h delete mode 100644 winml/adapter/LotusEnvironment.cpp delete mode 100644 winml/adapter/LotusEnvironment.h delete mode 100644 winml/adapter/WinMLAdapter.cpp delete mode 100644 winml/adapter/WinMLAdapter.h delete mode 100644 winml/adapter/WinMLAdapterErrors.h delete mode 100644 winml/adapter/ZeroCopyInputStreamWrapper.cpp delete mode 100644 winml/adapter/ZeroCopyInputStreamWrapper.h rename winml/adapter/{AbiCustomRegistryImpl.cpp => abi_custom_registry_impl.cpp} (98%) rename winml/adapter/{AbiCustomRegistryImpl.h => abi_custom_registry_impl.h} (93%) create mode 100644 winml/adapter/winml_adapter_apis.h create mode 100644 winml/adapter/winml_adapter_c_api.cpp create mode 100644 winml/adapter/winml_adapter_c_api.h create mode 100644 winml/adapter/winml_adapter_dml.cpp create mode 100644 winml/adapter/winml_adapter_environment.cpp create mode 100644 winml/adapter/winml_adapter_execution_provider.cpp create mode 100644 winml/adapter/winml_adapter_model.cpp create mode 100644 winml/adapter/winml_adapter_model.h create mode 100644 winml/adapter/winml_adapter_session.cpp create mode 100644 winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp create mode 100644 winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h rename winml/{adapter/FeatureDescriptorFactory.cpp => lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp} (58%) create mode 100644 winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h create mode 100644 winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp create mode 100644 winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h create mode 100644 winml/lib/Api.Ort/OnnxruntimeEngine.cpp create mode 100644 winml/lib/Api.Ort/OnnxruntimeEngine.h create mode 100644 winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp create mode 100644 winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h create mode 100644 winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp create mode 100644 winml/lib/Api.Ort/OnnxruntimeEnvironment.h create mode 100644 winml/lib/Api.Ort/OnnxruntimeErrors.h create mode 100644 winml/lib/Api.Ort/OnnxruntimeModel.cpp create mode 100644 winml/lib/Api.Ort/OnnxruntimeModel.h create mode 100644 winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h create mode 100644 winml/lib/Api.Ort/inc/OnnxruntimeProvider.h create mode 100644 winml/lib/Api.Ort/pch.h create mode 100644 winml/lib/Common/inc/iengine.h diff --git a/cmake/winml.cmake b/cmake/winml.cmake index e7a4c6374efa5..dba8d02b12805 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -8,12 +8,13 @@ include(winml_cppwinrt.cmake) # get the 
current nuget sdk kit directory get_sdk(sdk_folder sdk_version) set(target_folder ONNXRuntime/winml) +set(winml_adapter_dir ${REPO_ROOT}/winml/adapter) set(winml_api_root ${REPO_ROOT}/winml/api) set(winml_dll_dir ${REPO_ROOT}/winml/dll) set(winml_lib_dir ${REPO_ROOT}/winml/lib) set(winml_lib_api_dir ${REPO_ROOT}/winml/lib/api) -set(winml_adapter_dir ${REPO_ROOT}/winml/adapter) set(winml_lib_api_image_dir ${REPO_ROOT}/winml/lib/api.image) +set(winml_lib_api_ort_dir ${REPO_ROOT}/winml/lib/api.ort) set(winml_lib_common_dir ${REPO_ROOT}/winml/lib/common) set(winml_lib_telemetry_dir ${REPO_ROOT}/winml/lib/telemetry) @@ -116,32 +117,102 @@ set_target_properties(winml_lib_telemetry # Link libraries target_link_libraries(winml_lib_telemetry PRIVATE wil) +########################### +# Add winml_lib_ort +########################### + +list(APPEND winml_lib_api_ort_files + ${winml_lib_api_ort_dir}/inc/OnnxruntimeProvider.h + ${winml_lib_api_ort_dir}/OnnxruntimeCpuSessionBuilder.h + ${winml_lib_api_ort_dir}/OnnxruntimeCpuSessionBuilder.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeDescriptorConverter.h + ${winml_lib_api_ort_dir}/OnnxruntimeDescriptorConverter.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeEngine.h + ${winml_lib_api_ort_dir}/OnnxruntimeEngine.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeEngineBuilder.h + ${winml_lib_api_ort_dir}/OnnxruntimeEngineBuilder.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeEnvironment.h + ${winml_lib_api_ort_dir}/OnnxruntimeEnvironment.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeModel.h + ${winml_lib_api_ort_dir}/OnnxruntimeModel.cpp + ${winml_lib_api_ort_dir}/OnnxruntimeSessionBuilder.h + ${winml_lib_api_ort_dir}/pch.h + ) + +if (onnxruntime_USE_DML) + list(APPEND winml_lib_api_ort_files + ${winml_lib_api_ort_dir}/OnnxruntimeDmlSessionBuilder.h + ${winml_lib_api_ort_dir}/OnnxruntimeDmlSessionBuilder.cpp + ) +endif(onnxruntime_USE_DML) + +# Add static library that will be archived/linked for both static/dynamic library +add_library(winml_lib_ort STATIC ${winml_lib_api_ort_files}) + +# Compiler options +target_compile_features(winml_lib_ort PRIVATE cxx_std_17) +target_compile_options(winml_lib_ort PRIVATE /GR- /await /wd4238) + +# Compiler definitions +target_compile_definitions(winml_lib_ort PRIVATE PLATFORM_WINDOWS) +target_compile_definitions(winml_lib_ort PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators + +# Specify the usage of a precompiled header +target_precompiled_header(winml_lib_ort pch.h) + +# Includes +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers + +target_include_directories(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + +target_include_directories(winml_lib_ort PRIVATE ${REPO_ROOT}/winml) +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_dir}) # needed for generated headers +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_core_dir}) +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_api_ort_dir}) +target_include_directories(winml_lib_ort PRIVATE ${winml_lib_common_dir}/inc) +target_include_directories(winml_lib_ort PRIVATE ${ONNXRUNTIME_INCLUDE_DIR}) +target_include_directories(winml_lib_ort PRIVATE 
${ONNXRUNTIME_ROOT}) + +set_target_properties(winml_lib_ort + PROPERTIES + FOLDER + ${target_folder}) + +# Add deps +add_dependencies(winml_lib_ort winml_sdk_cppwinrt) +add_dependencies(winml_lib_ort winml_api) +add_dependencies(winml_lib_ort winml_api_native) +add_dependencies(winml_lib_ort winml_api_native_internal) + +# Link libraries +target_link_libraries(winml_lib_ort PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/packages/DirectML.0.0.1/build/DirectML.targets) +target_link_libraries(winml_lib_ort PRIVATE wil) + + ########################### # Add winml_adapter ########################### list(APPEND winml_adapter_files - ${winml_adapter_dir}/CpuOrtSessionBuilder.cpp - ${winml_adapter_dir}/CpuOrtSessionBuilder.h - ${winml_adapter_dir}/CustomRegistryHelper.h - ${winml_adapter_dir}/FeatureDescriptorFactory.cpp - ${winml_adapter_dir}/FeatureDescriptorFactory.h - ${winml_adapter_dir}/LotusEnvironment.cpp - ${winml_adapter_dir}/LotusEnvironment.h ${winml_adapter_dir}/pch.h - ${winml_adapter_dir}/WinMLAdapter.cpp - ${winml_adapter_dir}/WinMLAdapter.h - ${winml_adapter_dir}/ZeroCopyInputStreamWrapper.cpp - ${winml_adapter_dir}/ZeroCopyInputStreamWrapper.h + ${winml_adapter_dir}/winml_adapter_apis.h + ${winml_adapter_dir}/winml_adapter_c_api.h + ${winml_adapter_dir}/winml_adapter_c_api.cpp + ${winml_adapter_dir}/winml_adapter_dml.cpp + ${winml_adapter_dir}/winml_adapter_environment.cpp + ${winml_adapter_dir}/winml_adapter_execution_provider.cpp + ${winml_adapter_dir}/winml_adapter_model.cpp + ${winml_adapter_dir}/winml_adapter_model.h + ${winml_adapter_dir}/winml_adapter_session.cpp ) - + if (onnxruntime_USE_DML) list(APPEND winml_adapter_files - ${winml_adapter_dir}/AbiCustomRegistryImpl.cpp - ${winml_adapter_dir}/AbiCustomRegistryImpl.h - ${winml_adapter_dir}/DmlOrtSessionBuilder.cpp - ${winml_adapter_dir}/DmlOrtSessionBuilder.h - ) + ${winml_adapter_dir}/abi_custom_registry_impl.cpp + ${winml_adapter_dir}/abi_custom_registry_impl.h + ) endif(onnxruntime_USE_DML) add_library(winml_adapter ${winml_adapter_files}) @@ -329,6 +400,7 @@ target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_dir}) target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_dir}/pch) target_include_directories(winml_lib_api PRIVATE ${winml_adapter_dir}) target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_image_dir}/inc) +target_include_directories(winml_lib_api PRIVATE ${winml_lib_api_ort_dir}/inc) target_include_directories(winml_lib_api PRIVATE ${winml_lib_telemetry_dir}/inc) target_include_directories(winml_lib_api PRIVATE ${winml_lib_common_dir}/inc) @@ -370,6 +442,19 @@ endif(onnxruntime_USE_DML) ########################### add_library(winml_lib_common STATIC + ${winml_lib_common_dir}/inc/common.h + ${winml_lib_common_dir}/inc/CommonDeviceHelpers.h + ${winml_lib_common_dir}/inc/cppwinrt_onnx.h + ${winml_lib_common_dir}/inc/dx.h + ${winml_lib_common_dir}/inc/errors.h + ${winml_lib_common_dir}/inc/iengine.h + ${winml_lib_common_dir}/inc/NamespaceAliases.h + ${winml_lib_common_dir}/inc/onnx.h + ${winml_lib_common_dir}/inc/PheonixSingleton.h + ${winml_lib_common_dir}/inc/StringHelpers.h + ${winml_lib_common_dir}/inc/WinMLTelemetryHelper.h + ${winml_lib_common_dir}/inc/WinML_Lock.h + ${winml_lib_common_dir}/inc/winrt_headers.h ${winml_lib_common_dir}/CommonDeviceHelpers.cpp ) @@ -448,6 +533,7 @@ target_include_directories(winml_dll PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/s target_include_directories(winml_dll PRIVATE ${winml_dll_dir}) target_include_directories(winml_dll PRIVATE 
${winml_lib_api_dir}) target_include_directories(winml_dll PRIVATE ${winml_lib_api_dir}/impl) +target_include_directories(winml_dll PRIVATE ${winml_lib_api_ort_dir}/inc) target_include_directories(winml_dll PRIVATE ${winml_adapter_dir}) target_include_directories(winml_dll PRIVATE ${winml_lib_api_image_dir}/inc) target_include_directories(winml_dll PRIVATE ${winml_lib_telemetry_dir}/inc) @@ -514,6 +600,7 @@ target_link_libraries(winml_dll PRIVATE re2) target_link_libraries(winml_dll PRIVATE wil) target_link_libraries(winml_dll PRIVATE winml_lib_api) target_link_libraries(winml_dll PRIVATE winml_lib_image) +target_link_libraries(winml_dll PRIVATE winml_lib_ort) target_link_libraries(winml_dll PRIVATE winml_lib_telemetry) target_link_libraries(winml_dll PRIVATE delayimp.lib) target_link_libraries(winml_dll PRIVATE ${DBGHELP}) diff --git a/include/onnxruntime/core/providers/winml/winml_provider_factory.h b/include/onnxruntime/core/providers/winml/winml_provider_factory.h index b4d4a754d2460..b08b42e310e41 100644 --- a/include/onnxruntime/core/providers/winml/winml_provider_factory.h +++ b/include/onnxruntime/core/providers/winml/winml_provider_factory.h @@ -1,14 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "onnxruntime_c_api.h" -#ifdef __cplusplus -#include -using namespace Windows::AI::MachineLearning::Adapter; -#else -struct IWinMLAdapter; -typedef struct IWinMLAdapter IWinMLAdapter; -#endif +#include "onnxruntime_c_api.h" -ORT_EXPORT STDAPI OrtGetWinMLAdapter(IWinMLAdapter** adapter); +struct WinmlAdapterApi; +typedef struct WinmlAdapterApi WinmlAdapterApi; +ORT_EXPORT const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(_In_ const OrtApi* ort_api) NO_EXCEPTION; \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 176b988ad0f3b..ee05614229f05 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -156,6 +156,8 @@ ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); ORT_RUNTIME_CLASS(SessionOptions); ORT_RUNTIME_CLASS(CustomOpDomain); +ORT_RUNTIME_CLASS(MapTypeInfo); +ORT_RUNTIME_CLASS(SequenceTypeInfo); // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c466b0cb8a79c..a97a5d413f904 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -75,17 +75,11 @@ ORT_DEFINE_RELEASE(Value); // This is used internally by the C++ API. This is the common base class used by the wrapper objects. 
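Context for the Base changes in the hunk below: the wrapper is a move-only RAII handle around a raw Ort* pointer, and the underlying ReleaseXxx functions tolerate null, which is what lets the simplified destructor drop its null check. A hypothetical usage sketch (MakeValue is an assumed factory, not part of the patch):

    #include <utility>
    #include "onnxruntime_cxx_api.h"

    Ort::Value MakeValue();  // assumed factory returning some tensor value

    void Demo() {
      Ort::Value a = MakeValue();    // RAII: owns an OrtValue*
      Ort::Value b = std::move(a);   // move-only: ownership transfers, a now holds nullptr
      OrtValue* raw = b.release();   // detach: caller is now responsible for the handle
      Ort::Value c{raw};             // re-adopt the raw handle into RAII ownership
    }  // destructors run; releasing the moved-from (null) handle is a no-op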
 template <typename T>
 struct Base {
-  Base() {
-    p_ = nullptr;
-  }
+  Base() = default;
   Base(T* p) : p_{p} {
     if (!p) throw Ort::Exception("Allocation failure", ORT_FAIL);
   }
-  ~Base() {
-    if (p_ != nullptr) {
-      OrtRelease(p_);
-    }
-  }
+  ~Base() { OrtRelease(p_); }

   operator T*() { return p_; }
   operator const T*() const { return p_; }
@@ -96,19 +90,12 @@ struct Base {
     return p;
   }

-  T** put() noexcept {
-    assert(p_ == nullptr);
-    return &p_;
-  }
-
  protected:
   Base(const Base&) = delete;
   Base& operator=(const Base&) = delete;
   Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
   void operator=(Base&& v) noexcept {
-    if (p_ != nullptr) {
-      OrtRelease(p_);
-    }
+    OrtRelease(p_);
     p_ = v.p_;
     v.p_ = nullptr;
   }
@@ -275,7 +262,6 @@ struct Value : Base<OrtValue> {
   size_t GetStringTensorDataLength() const;
   void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;

-  std::vector<std::string> GetStrings();
   template <typename T>
   T* GetTensorMutableData();
@@ -306,9 +292,6 @@ struct MemoryInfo : Base<OrtMemoryInfo> {
   MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);

   explicit MemoryInfo(OrtMemoryInfo* p) : Base<OrtMemoryInfo>{p} {}
-
-  const char* Name() const;
-  OrtMemType MemType() const;
 };

 //
@@ -371,4 +354,4 @@ struct CustomOpBase : OrtCustomOp {
 }  // namespace Ort

-#include "onnxruntime_cxx_inline.h"
+#include "onnxruntime_cxx_inline.h"
\ No newline at end of file
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index be2fe6bf9e2c2..f6fb350171f01 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -76,18 +76,6 @@ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, O
   ThrowOnError(Global<void>::api_.CreateMemoryInfo(name, type, id, mem_type, &p_));
 }

-inline const char* MemoryInfo::Name() const {
-  const char* out = nullptr;
-  ThrowOnError(Global<void>::api_.MemoryInfoGetName(p_, &out));
-  return out;
-}
-
-inline OrtMemType MemoryInfo::MemType() const {
-  OrtMemType out;
-  ThrowOnError(Global<void>::api_.MemoryInfoGetMemType(p_, &out));
-  return out;
-}
-
 inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) {
   ThrowOnError(Global<void>::api_.CreateEnv(default_warning_level, logid, &p_));
 }
@@ -357,21 +345,6 @@ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_
   return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
 }

-template <>
-inline Value Value::CreateTensor(const OrtMemoryInfo*, std::string* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
-  // convert the array of std::string to an array of const char *
-  std::vector<const char*> string_vector;
-  for (size_t i = 0; i < p_data_element_count; ++i) {
-    string_vector.push_back(p_data[i].c_str());
-  }
-  // now make an empty tensor using the default allocator (strings have to make a copy)
-  AllocatorWithDefaultOptions allocator;
-  auto tensor = Value::CreateTensor(static_cast<OrtAllocator*>(allocator), shape, shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
-  // now fill the string data
-  ThrowOnError(GetApi().FillStringTensor(tensor, string_vector.data(), string_vector.size()));
-  return tensor;
-}
-
 inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
                                  ONNXTensorElementDataType type) {
   OrtValue* out;
@@ -444,33 +417,6 @@ inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, si
   ThrowOnError(Global<void>::api_.GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count));
 }

-inline std::vector<std::string> Value::GetStrings() {
-  std::vector<std::string> out;
-  // make sure this is an array of strings
-  auto shape = this->GetTensorTypeAndShapeInfo().GetShape();
-  // there needs to be only one dimension
-  if (shape.size() != 1) throw Ort::Exception("shape.size() != 1", ORT_INVALID_ARGUMENT);
-  // make a big buffer to hold all the string data
-  size_t buflen = this->GetStringTensorDataLength();
-  std::vector<uint8_t> buf(buflen);
-  std::vector<size_t> offsets(shape[0]);
-  this->GetStringTensorContent(buf.data(), buf.size(), offsets.data(), offsets.size());
-  // now go build all the strings
-  for (auto i = 0; i < shape[0]; ++i) {
-    std::string str;
-    size_t strlen = 0;
-    // are we on the last one?
-    if (i == (shape[0] - 1ll)) {
-      strlen = buflen - offsets[i];
-    } else {
-      strlen = offsets[i + 1ll] - offsets[i];
-    }
-    str.append(reinterpret_cast<const char*>(buf.data() + offsets[i]), strlen);
-    out.push_back(str);
-  }
-  return out;
-}
-
 template <typename T>
 T* Value::GetTensorMutableData() {
   T* out;
@@ -607,4 +553,4 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
   return out;
 }

-}  // namespace Ort
+}  // namespace Ort
\ No newline at end of file
diff --git a/onnxruntime/core/framework/allocatormgr.cc b/onnxruntime/core/framework/allocatormgr.cc
index f4258d5a6a889..a38d89a9e2bb8 100644
--- a/onnxruntime/core/framework/allocatormgr.cc
+++ b/onnxruntime/core/framework/allocatormgr.cc
@@ -29,9 +29,4 @@ AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id
   return AllocatorPtr(std::move(device_allocator));
 }

-DeviceAllocatorRegistry& DeviceAllocatorRegistry::Instance() {
-  static DeviceAllocatorRegistry s_instance;
-  return s_instance;
-}
-
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/allocatormgr.h b/onnxruntime/core/framework/allocatormgr.h
index 3985fd4b66a98..aa346fc52f575 100644
--- a/onnxruntime/core/framework/allocatormgr.h
+++ b/onnxruntime/core/framework/allocatormgr.h
@@ -18,25 +18,4 @@ struct DeviceAllocatorRegistrationInfo {

 AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id = 0);

-class DeviceAllocatorRegistry {
- public:
-  void RegisterDeviceAllocator(std::string&& name, DeviceAllocatorFactory factory, size_t max_mem,
-                               OrtMemType mem_type = OrtMemTypeDefault) {
-    DeviceAllocatorRegistrationInfo info({mem_type, factory, max_mem});
-    device_allocator_registrations_.emplace(std::move(name), std::move(info));
-  }
-
-  const std::map<std::string, DeviceAllocatorRegistrationInfo>& AllRegistrations() const {
-    return device_allocator_registrations_;
-  }
-
-  static DeviceAllocatorRegistry& Instance();
-
- private:
-  DeviceAllocatorRegistry() = default;
-  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeviceAllocatorRegistry);
-
-  std::map<std::string, DeviceAllocatorRegistrationInfo> device_allocator_registrations_;
-};
-
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
new file mode 100644
index 0000000000000..107cdbbed10c2
--- /dev/null
+++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/framework/onnxruntime_map_type_info.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type) noexcept : map_key_type_(map_key_type), map_value_type_(map_value_type, &OrtApis::ReleaseTypeInfo) { +} + +static ONNXTensorElementDataType +ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { + using TensorType = ONNX_NAMESPACE::TensorProto_DataType; + switch (data_type) { + case TensorType::TensorProto_DataType_BOOL: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; } + case TensorType::TensorProto_DataType_STRING: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; } // maps to c++ type std::string + case TensorType::TensorProto_DataType_FLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } // maps to c type float + case TensorType::TensorProto_DataType_FLOAT: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; } + case TensorType::TensorProto_DataType_DOUBLE: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; } // maps to c type double + case TensorType::TensorProto_DataType_INT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; } // maps to c type int8_t + case TensorType::TensorProto_DataType_INT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; } // maps to c type int16_t + case TensorType::TensorProto_DataType_INT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } // maps to c type int32_t + case TensorType::TensorProto_DataType_INT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; } // maps to c type int64_t + case TensorType::TensorProto_DataType_UINT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; } // maps to c type uint8_t + case TensorType::TensorProto_DataType_UINT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; } // maps to c type uint16_t + case TensorType::TensorProto_DataType_UINT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; } // maps to c type uint32_t + case TensorType::TensorProto_DataType_UINT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; } // maps to c type uint64_t + case TensorType::TensorProto_DataType_COMPLEX64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; } // complex with float32 real and imaginary components + case TensorType::TensorProto_DataType_COMPLEX128: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; } // complex with float64 real and imaginary components + case TensorType::TensorProto_DataType_BFLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; } // Non-IEEE floating-point format based on IEEE754 single-precision + default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } + } +} + +OrtStatus* OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtMapTypeInfo** out) { + auto value_case = type_proto->value_case(); + if (value_case != ONNX_NAMESPACE::TypeProto::kMapType) + { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type map!");; + } + + // Get the key type of the map + auto type_proto_map = 
+  auto map_key_type = ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type()));
+
+  // Get the value type of the map
+  OrtTypeInfo* map_value_type_info = nullptr;
+  if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_map.value_type(), &map_value_type_info))
+  {
+    return status;
+  }
+
+  *out = new OrtMapTypeInfo(map_key_type, map_value_type_info);
+  return nullptr;
+}
+
+OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) {
+  OrtTypeInfo* map_value_type_copy = nullptr;
+  if (auto status = map_value_type_->Clone(&map_value_type_copy))
+  {
+    return status;
+  }
+  *out = new OrtMapTypeInfo(map_key_type_, map_value_type_copy);
+  return nullptr;
+}
+
+// OrtMapTypeInfo Accessors
+ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, const OrtMapTypeInfo* map_type_info, enum ONNXTensorElementDataType* out) {
+  API_IMPL_BEGIN
+  *out = map_type_info->map_key_type_;
+  return nullptr;
+  API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, const OrtMapTypeInfo* map_type_info, OrtTypeInfo** out) {
+  API_IMPL_BEGIN
+  return map_type_info->map_value_type_->Clone(out);
+  API_IMPL_END
+}
+
+ORT_API(void, OrtApis::ReleaseMapTypeInfo, OrtMapTypeInfo* ptr) {
+  delete ptr;
+}
\ No newline at end of file
diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.h b/onnxruntime/core/framework/onnxruntime_map_type_info.h
new file mode 100644
index 0000000000000..2d9297c8cb2d4
--- /dev/null
+++ b/onnxruntime/core/framework/onnxruntime_map_type_info.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include "onnxruntime_c_api.h"
+
+#include <memory>
+
+namespace ONNX_NAMESPACE {
+class TypeProto;
+}
+
+struct OrtMapTypeInfo {
+ public:
+  ONNXTensorElementDataType map_key_type_ = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
+  std::unique_ptr<OrtTypeInfo, decltype(OrtApis::ReleaseTypeInfo)*> map_value_type_;
+
+  static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtMapTypeInfo** out);
+
+  OrtStatus* Clone(OrtMapTypeInfo** out);
+
+ private:
+  OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type) noexcept;
+  OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete;
+  OrtMapTypeInfo& operator=(const OrtMapTypeInfo& other) = delete;
+
+};
diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc
new file mode 100644
index 0000000000000..a5ee0c9a63bb1
--- /dev/null
+++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc
@@ -0,0 +1,49 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/framework/onnxruntime_sequence_type_info.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +OrtSequenceTypeInfo::OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept : + sequence_key_type_(sequence_key_type, &OrtApis::ReleaseTypeInfo) { +} + +OrtStatus* OrtSequenceTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtSequenceTypeInfo** out) { + auto value_case = type_proto->value_case(); + if (value_case != ONNX_NAMESPACE::TypeProto::kSequenceType) + { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type sequence!");; + } + + auto type_proto_sequence = type_proto->sequence_type(); + OrtTypeInfo* sequence_key_type_info = nullptr; + if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_sequence.elem_type(), &sequence_key_type_info)) + { + return status; + } + + *out = new OrtSequenceTypeInfo(sequence_key_type_info); + return nullptr; +} + +OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) { + OrtTypeInfo* sequence_key_type_copy = nullptr; + if (auto status = sequence_key_type_->Clone(&sequence_key_type_copy)) + { + return status; + } + *out = new OrtSequenceTypeInfo(sequence_key_type_copy); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, const OrtSequenceTypeInfo* sequence_type_info, OrtTypeInfo** out) { + API_IMPL_BEGIN + return sequence_type_info->sequence_key_type_->Clone(out); + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, OrtSequenceTypeInfo* ptr) { + delete ptr; +} \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h new file mode 100644 index 0000000000000..6efa55c8de763 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include "onnxruntime_c_api.h" + +#include + +namespace ONNX_NAMESPACE { +class TypeProto; +} + +struct OrtSequenceTypeInfo { + public: + std::unique_ptr sequence_key_type_; + + OrtStatus* Clone(OrtSequenceTypeInfo** out); + + static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtSequenceTypeInfo** out); + + private: + OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept; + OrtSequenceTypeInfo(const OrtSequenceTypeInfo& other) = delete; + OrtSequenceTypeInfo& operator=(const OrtSequenceTypeInfo& other) = delete; +}; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 080a3518048cb..42e03e802caf1 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -10,6 +10,11 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/onnxruntime_map_type_info.h" +#include "core/framework/onnxruntime_sequence_type_info.h" using onnxruntime::BFloat16; using onnxruntime::DataTypeImpl; @@ -20,11 +25,27 @@ using onnxruntime::TensorShape; namespace on = ONNX_NAMESPACE; +OrtTypeInfo::OrtTypeInfo(ONNXType type1) noexcept : type(type1) { +} + OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) { } +OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtMapTypeInfo* map_type_info1) noexcept : type(type1), map_type_info(map_type_info1) { +} + +OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtSequenceTypeInfo* sequence_type_info1) noexcept : type(type1), sequence_type_info(sequence_type_info1) { +} + OrtTypeInfo::~OrtTypeInfo() { OrtApis::ReleaseTensorTypeAndShapeInfo(data); + + if (map_type_info) { + OrtApis::ReleaseMapTypeInfo(map_type_info); + } + if (sequence_type_info) { + OrtApis::ReleaseSequenceTypeInfo(sequence_type_info); + } } ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input, ONNXType* out) { @@ -37,6 +58,28 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtType return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, const OrtTypeInfo* type_info, const OrtMapTypeInfo** out) { + API_IMPL_BEGIN + *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, const OrtTypeInfo* type_info, const OrtSequenceTypeInfo** out) { + API_IMPL_BEGIN + *out = type_info->type == ONNX_TYPE_SEQUENCE ? 
type_info->sequence_type_info : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, const OrtTypeInfo* type_info, const char** const out, size_t* len) { + API_IMPL_BEGIN + *out = type_info->denotation.c_str(); + *len = type_info->denotation.size(); + return nullptr; + API_IMPL_END +} + ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { delete ptr; } @@ -49,7 +92,7 @@ OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vectorIsTensorSequenceType()) { - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr); + *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE); return nullptr; } @@ -92,16 +135,14 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { // Place Opaque first as tensors will be mostly handled above and maps and sequences are not common switch (type_proto->value_case()) { case on::TypeProto::kOpaqueType: { - *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr); + *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE); return nullptr; } case on::TypeProto::kMapType: { - *out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr); - return nullptr; + return OrtTypeInfo::FromTypeProto(type_proto, out); } case on::TypeProto::kSequenceType: { - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr); - return nullptr; + return OrtTypeInfo::FromTypeProto(type_proto, out); } // Real Tensor support case on::TypeProto::kTensorType: @@ -204,19 +245,39 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info); } if (st != nullptr) return st; - *out = new OrtTypeInfo(ten_type, info); + auto type_info = new OrtTypeInfo(ten_type, info); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::kSequenceType: { - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr); + OrtSequenceTypeInfo* sequence_type_info = nullptr; + + if (auto status = OrtSequenceTypeInfo::FromTypeProto(input, &sequence_type_info)) { + return status; + } + + auto type_info = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, sequence_type_info); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::kMapType: { - *out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr); + OrtMapTypeInfo* map_type_info = nullptr; + + if (auto status = OrtMapTypeInfo::FromTypeProto(input, &map_type_info)) { + return status; + } + + auto type_info = new OrtTypeInfo(ONNX_TYPE_MAP, map_type_info); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::kOpaqueType: { - *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr); + auto type_info = new OrtTypeInfo(ONNX_TYPE_OPAQUE); + type_info->denotation = input->denotation(); + *out = type_info; return nullptr; } break; case on::TypeProto::VALUE_NOT_SET: @@ -227,3 +288,48 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or } return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); } + +OrtStatus* OrtTypeInfo::Clone(OrtTypeInfo** out) { + switch (type) { + case ONNX_TYPE_TENSOR: + case ONNX_TYPE_SPARSETENSOR: + { + OrtTensorTypeAndShapeInfo* clone; + if (auto status = data->Clone(&clone)) { + return status; + } + *out = new OrtTypeInfo(type, clone); + (*out)->denotation = denotation; + return nullptr; + } + case ONNX_TYPE_SEQUENCE: + { + OrtSequenceTypeInfo* clone; + if (auto status = sequence_type_info->Clone(&clone)) { + return status; + } + *out = new 
OrtTypeInfo(type, clone); + (*out)->denotation = denotation; + return nullptr; + } + case ONNX_TYPE_MAP: { + OrtMapTypeInfo* clone; + if (auto status = map_type_info->Clone(&clone)) { + return status; + } + *out = new OrtTypeInfo(type, clone); + (*out)->denotation = denotation; + return nullptr; + } + case ONNX_TYPE_OPAQUE: + { + *out = new OrtTypeInfo(type); + (*out)->denotation = denotation; + return nullptr; + } + default: + // Not implemented + break; + } + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); +} \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index d615840dcb501..3c256aa73d17d 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { @@ -14,6 +15,10 @@ namespace ONNX_NAMESPACE { class TypeProto; } +// These types are only present in the winml adapter c api, so they are forward declared. +struct OrtMapTypeInfo; +struct OrtSequenceTypeInfo; + /** * the equivalent of ONNX_NAMESPACE::TypeProto * This class is mainly for the C API @@ -21,19 +26,26 @@ class TypeProto; struct OrtTypeInfo { public: ONNXType type = ONNX_TYPE_UNKNOWN; + std::string denotation; ~OrtTypeInfo(); //owned by this OrtTensorTypeAndShapeInfo* data = nullptr; + OrtMapTypeInfo* map_type_info = nullptr; + OrtSequenceTypeInfo* sequence_type_info = nullptr; OrtTypeInfo(const OrtTypeInfo& other) = delete; OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete; + OrtStatus* Clone(OrtTypeInfo** out); + static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out); static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out); - static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); private: + OrtTypeInfo(ONNXType type) noexcept; OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept; + OrtTypeInfo(ONNXType type, OrtMapTypeInfo* map_type_info) noexcept; + OrtTypeInfo(ONNXType type, OrtSequenceTypeInfo* sequence_type_info) noexcept; }; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 088043a159962..64bb11dbcbcfa 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -192,6 +192,11 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const st return GetTensorShapeAndTypeHelper(type, shape, dim_params, out); } +OrtStatus* OrtTensorTypeAndShapeInfo::Clone(OrtTensorTypeAndShapeInfo** out) +{ + return GetTensorShapeAndTypeHelper(type, shape, &dim_params, out); +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN onnxruntime::MLDataType type = v->Type(); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index 28431a9d614cf..f781160cc6505 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -13,4 +13,6 @@ struct OrtTensorTypeAndShapeInfo { OrtTensorTypeAndShapeInfo() = default; OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete; + + OrtStatus* Clone(OrtTensorTypeAndShapeInfo** 
out); }; diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 00dbab1536e86..079c444197289 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -26,14 +26,22 @@ struct DMLProviderFactory : IExecutionProviderFactory { ~DMLProviderFactory() override {} std::unique_ptr<onnxruntime::IExecutionProvider> CreateProvider() override; + void SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode); private: ComPtr<IDMLDevice> dml_device_{}; ComPtr<ID3D12CommandQueue> cmd_queue_{}; + AllocatorRoundingMode rounding_mode_ = AllocatorRoundingMode::Enabled; }; std::unique_ptr<onnxruntime::IExecutionProvider> DMLProviderFactory::CreateProvider() { - return Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get()); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get()); + Dml::SetDefaultRoundingMode(provider.get(), rounding_mode_); + return provider; +} + +void DMLProviderFactory::SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode) { + rounding_mode_ = rounding_mode; } std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(IDMLDevice* dml_device, @@ -57,8 +65,12 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(ID return std::make_shared<DMLProviderFactory>(dml_device, cmd_queue); } -bool IsSoftwareAdapter(IDXGIAdapter1* adapter) -{ +void DmlConfigureProviderFactoryDefaultRoundingMode(IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode) { + auto dml_provider_factory = static_cast<DMLProviderFactory*>(factory); + dml_provider_factory->SetDefaultRoundingMode(rounding_mode); +} + +bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { DXGI_ADAPTER_DESC1 desc; adapter->GetDesc1(&desc); @@ -96,7 +108,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled #if _DEBUG ComPtr<ID3D12DebugDevice> debug_device; - (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure + (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr); if (is_d3d12_debug_layer_enabled) { @@ -110,7 +122,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dml_device))); - return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b9d4714e000e4..482bea568a701 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -12,13 +12,13 @@ #include #include "core/common/logging/logging.h" -#include "core/common/logging/sinks/clog_sink.h" #include "core/common/status.h" #include "core/graph/graph.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" #include "core/framework/ml_value.h" #include "core/session/environment.h" +#include "core/session/onnxruntime_env.h" #include "core/framework/callback.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/onnxruntime_typeinfo.h" @@ -49,112 +49,6 @@ using namespace onnxruntime; if (_status) return _status; \ } while (0) -class LoggingWrapper : public ISink { - public: - LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param) - : logging_function_(logging_function), logger_param_(logger_param) { - } - - void SendImpl(const Timestamp& /*timestamp*/, const std::string& logger_id, - const Capture& message) override { - std::string s =
message.Location().ToString(); - logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), - logger_id.c_str(), s.c_str(), message.Message().c_str()); - } - - private: - OrtLoggingFunction logging_function_; - void* logger_param_; -}; - -struct OrtEnv { - public: - struct LoggingManagerConstructionInfo { - LoggingManagerConstructionInfo(OrtLoggingFunction logging_function1, - void* logger_param1, - OrtLoggingLevel default_warning_level1, - const char* logid1) - : logging_function(logging_function1), - logger_param(logger_param1), - default_warning_level(default_warning_level1), - logid(logid1) {} - OrtLoggingFunction logging_function{}; - void* logger_param{}; - OrtLoggingLevel default_warning_level; - const char* logid{}; - }; - - static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, Status& status) { - std::lock_guard lock(m_); - if (!p_instance_) { - std::unique_ptr env; - status = Environment::Create(env); - if (!status.IsOK()) { - return nullptr; - } - - std::unique_ptr lmgr; - std::string name = lm_info.logid; - if (lm_info.logging_function) { - std::unique_ptr logger = onnxruntime::make_unique(lm_info.logging_function, - lm_info.logger_param); - lmgr.reset(new LoggingManager(std::move(logger), - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } else { - lmgr.reset(new LoggingManager(std::unique_ptr{new CLogSink{}}, - static_cast(lm_info.default_warning_level), - false, - LoggingManager::InstanceType::Default, - &name)); - } - - p_instance_ = new OrtEnv(std::move(env), std::move(lmgr)); - } - ++ref_count_; - return p_instance_; - } - - static void Release(OrtEnv* env_ptr) { - if (!env_ptr) { - return; - } - std::lock_guard lock(m_); - ORT_ENFORCE(env_ptr == p_instance_); // sanity check - --ref_count_; - if (ref_count_ == 0) { - delete p_instance_; - p_instance_ = nullptr; - } - } - - LoggingManager* GetLoggingManager() const { - return logging_manager_.get(); - } - - private: - static OrtEnv* p_instance_; - static OrtMutex m_; - static int ref_count_; - - std::unique_ptr value_; - std::unique_ptr logging_manager_; - - OrtEnv(std::unique_ptr value1, std::unique_ptr logging_manager) - : value_(std::move(value1)), logging_manager_(std::move(logging_manager)) { - } - - ~OrtEnv() = default; - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); -}; - -OrtEnv* OrtEnv::p_instance_ = nullptr; -int OrtEnv::ref_count_ = 0; -OrtMutex OrtEnv::m_; - #define TENSOR_READ_API_BEGIN \ API_IMPL_BEGIN \ auto v = reinterpret_cast(value); \ @@ -1451,6 +1345,10 @@ static constexpr OrtApi ort_api_1 = { &OrtApis::ReleaseCustomOpDomain, }; +const OrtApi* GetVersion1Api() { + return &ort_api_1; +} + ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version > 1) return nullptr; @@ -1472,4 +1370,4 @@ ORT_API(void, OrtApis::ReleaseEnv, _Frees_ptr_opt_ OrtEnv* value) { DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions) -DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) +DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) \ No newline at end of file diff --git a/onnxruntime/core/session/onnxruntime_env.cc b/onnxruntime/core/session/onnxruntime_env.cc new file mode 100644 index 0000000000000..c37a2543a8eed --- /dev/null +++ b/onnxruntime/core/session/onnxruntime_env.cc @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
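For reference, the version gate that `GetVersion1Api`/`OrtApis::GetApi` implement above is what every C-API consumer goes through. A minimal consumer-side sketch (public C API only; `main` and the log id are illustrative):

```cpp
#include <cstdio>
#include "core/session/onnxruntime_c_api.h"

int main() {
  // GetApi returns nullptr for any version newer than the binary supports
  // (at this point in the tree, anything > 1).
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  if (!api) {
    std::fprintf(stderr, "installed onnxruntime is too old for this client\n");
    return 1;
  }
  OrtEnv* env = nullptr;
  if (OrtStatus* status = api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "demo", &env)) {
    std::fprintf(stderr, "CreateEnv failed: %s\n", api->GetErrorMessage(status));
    api->ReleaseStatus(status);
    return 1;
  }
  api->ReleaseEnv(env);
  return 0;
}
```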
+ +//this file contains implementations of the C API + +#include + +#include "onnxruntime_env.h" +#include "core/session/ort_apis.h" +#include "core/session/environment.h" +#include "core/common/logging/sinks/clog_sink.h" +#include "core/common/logging/logging.h" + +using namespace onnxruntime; +using namespace onnxruntime::logging; + +OrtEnv* OrtEnv::p_instance_ = nullptr; +int OrtEnv::ref_count_ = 0; +onnxruntime::OrtMutex OrtEnv::m_; + +LoggingWrapper::LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param) + : logging_function_(logging_function), logger_param_(logger_param) { +} + +void LoggingWrapper::SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, + const onnxruntime::logging::Capture& message) { + std::string s = message.Location().ToString(); + logging_function_(logger_param_, static_cast<OrtLoggingLevel>(message.Severity()), message.Category(), + logger_id.c_str(), s.c_str(), message.Message().c_str()); +} + +OrtEnv::OrtEnv(std::unique_ptr<onnxruntime::Environment> value1, std::unique_ptr<LoggingManager> logging_manager) + : value_(std::move(value1)), logging_manager_(std::move(logging_manager)) { +} + +OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status) { + std::lock_guard<onnxruntime::OrtMutex> lock(m_); + if (!p_instance_) { + std::unique_ptr<Environment> env; + status = onnxruntime::Environment::Create(env); + if (!status.IsOK()) { + return nullptr; + } + + std::unique_ptr<LoggingManager> lmgr; + std::string name = lm_info.logid; + if (lm_info.logging_function) { + std::unique_ptr<ISink> logger = onnxruntime::make_unique<LoggingWrapper>(lm_info.logging_function, + lm_info.logger_param); + lmgr.reset(new LoggingManager(std::move(logger), + static_cast<Severity>(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &name)); + } else { + lmgr.reset(new LoggingManager(std::unique_ptr<ISink>{new CLogSink{}}, + static_cast<Severity>(lm_info.default_warning_level), + false, + LoggingManager::InstanceType::Default, + &name)); + } + + p_instance_ = new OrtEnv(std::move(env), std::move(lmgr)); + } + ++ref_count_; + return p_instance_; +} + +void OrtEnv::Release(OrtEnv* env_ptr) { + if (!env_ptr) { + return; + } + std::lock_guard<onnxruntime::OrtMutex> lock(m_); + ORT_ENFORCE(env_ptr == p_instance_); // sanity check + --ref_count_; + if (ref_count_ == 0) { + delete p_instance_; + p_instance_ = nullptr; + } +} + +LoggingManager* OrtEnv::GetLoggingManager() const { + return logging_manager_.get(); +} + +void OrtEnv::SetLoggingManager(std::unique_ptr<LoggingManager> logging_manager) { + std::lock_guard<onnxruntime::OrtMutex> lock(m_); + logging_manager_ = std::move(logging_manager); +} \ No newline at end of file diff --git a/onnxruntime/core/session/onnxruntime_env.h b/onnxruntime/core/session/onnxruntime_env.h new file mode 100644 index 0000000000000..c93d2937c7a7b --- /dev/null +++ b/onnxruntime/core/session/onnxruntime_env.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
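The `OrtEnv` implementation above is a ref-counted singleton: every successful `GetInstance` must eventually be paired with a `Release`, and the logging manager is built around the caller's `OrtLoggingFunction` via `LoggingWrapper`. A sketch of the host-side wiring (the callback body and the `MyLogger`/`Demo` names are illustrative, not part of the patch):

```cpp
#include <cstdio>
#include "core/session/onnxruntime_env.h"

// Matches OrtLoggingFunction; LoggingWrapper::SendImpl forwards
// (severity, category, logger_id, code_location, message) in this order.
static void ORT_API_CALL MyLogger(void* /*param*/, OrtLoggingLevel severity,
                                  const char* category, const char* logid,
                                  const char* code_location, const char* message) {
  std::fprintf(stderr, "[%d][%s] %s %s: %s\n",
               static_cast<int>(severity), category, logid, code_location, message);
}

void Demo() {
  onnxruntime::common::Status status;
  OrtEnv::LoggingManagerConstructionInfo lm_info(MyLogger, /*logger_param*/ nullptr,
                                                 ORT_LOGGING_LEVEL_WARNING, "winml");
  OrtEnv* env = OrtEnv::GetInstance(lm_info, status);  // bumps the ref count
  if (env == nullptr || !status.IsOK()) return;
  // ... create sessions against env ...
  OrtEnv::Release(env);  // the singleton is destroyed when the count reaches zero
}
```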
+ +#pragma once +#include <memory> +#include <string> +#include "core/session/onnxruntime_c_api.h" +#include "core/common/logging/isink.h" +#include "core/platform/ort_mutex.h" +#include "core/common/status.h" + +namespace onnxruntime { +class Environment; +} + +class LoggingWrapper : public onnxruntime::logging::ISink { + public: + LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param); + + void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, + const onnxruntime::logging::Capture& message) override; + + private: + OrtLoggingFunction logging_function_; + void* logger_param_; +}; + +struct OrtEnv { + public: + struct LoggingManagerConstructionInfo { + LoggingManagerConstructionInfo(OrtLoggingFunction logging_function1, + void* logger_param1, + OrtLoggingLevel default_warning_level1, + const char* logid1) + : logging_function(logging_function1), + logger_param(logger_param1), + default_warning_level(default_warning_level1), + logid(logid1) {} + OrtLoggingFunction logging_function{}; + void* logger_param{}; + OrtLoggingLevel default_warning_level; + const char* logid{}; + }; + + static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, onnxruntime::common::Status& status); + + static void Release(OrtEnv* env_ptr); + + onnxruntime::logging::LoggingManager* GetLoggingManager() const; + + void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager); + + private: + static OrtEnv* p_instance_; + static onnxruntime::OrtMutex m_; + static int ref_count_; + + std::unique_ptr<onnxruntime::Environment> value_; + std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager_; + + OrtEnv(std::unique_ptr<onnxruntime::Environment> value1, std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager); + ~OrtEnv() = default; + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); +}; \ No newline at end of file diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index cdc1ea7b6900f..4e3bf2274aaf4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -16,6 +16,8 @@ ORT_API(void, ReleaseTypeInfo, OrtTypeInfo*); ORT_API(void, ReleaseTensorTypeAndShapeInfo, OrtTensorTypeAndShapeInfo*); ORT_API(void, ReleaseSessionOptions, OrtSessionOptions*); ORT_API(void, ReleaseCustomOpDomain, OrtCustomOpDomain*); +ORT_API(void, ReleaseMapTypeInfo, OrtMapTypeInfo*); +ORT_API(void, ReleaseSequenceTypeInfo, OrtSequenceTypeInfo*); ORT_API_STATUS_IMPL(CreateStatus, OrtErrorCode code, _In_ const char* msg); OrtErrorCode ORT_API_CALL GetErrorCode(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; @@ -144,4 +146,16 @@ ORT_API_STATUS_IMPL(KernelContext_GetOutputCount, _In_ const OrtKernelContext* c ORT_API_STATUS_IMPL(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out); +// OrtTypeInfo methods +ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len); +ORT_API_STATUS_IMPL(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const OrtMapTypeInfo** out); +ORT_API_STATUS_IMPL(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const OrtSequenceTypeInfo** out); + +// OrtMapTypeInfo Accessors +ORT_API_STATUS_IMPL(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); +ORT_API_STATUS_IMPL(GetMapValueType, _In_ const
OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); + +// OrtSequenceTypeInfo Accessors +ORT_API_STATUS_IMPL(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info); + } // namespace OrtApis diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index de7185e091f93..6a10e2f0f6707 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -1,14 +1,4 @@ jobs: -- template: templates/win-ci.yml - parameters: - AgentPool : 'Win-CPU' - DoDebugBuild: 'true' - DoCompliance: 'false' - BuildCommand: '$(Build.SourcesDirectory)\tools\ci_build\build.py --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_path $(Build.BinariesDirectory)\cmake\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake\bin\ctest.exe --use_tvm --use_automl --enable_pybind --use_mkldnn --use_openmp --use_winml --build_shared_lib --build_csharp --enable_onnx_tests' - JobName: 'Windows_CI_Dev' - DoNugetPack: 'false' - NuPackScript : '' - DoTestCoverage: 'false' - job: 'build' pool: 'Win-CPU-2019' strategy: @@ -66,7 +56,7 @@ jobs: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --build_wheel --use_featurizers --use_dnnl --use_openmp --build_shared_lib --enable_onnx_tests --build_java' + arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --build_wheel --use_featurizers --use_dnnl --use_winml --use_openmp --build_shared_lib --enable_onnx_tests --build_java' workingDirectory: '$(Build.BinariesDirectory)' - task: VSBuild@1 diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 0bf9c923e78b9..3ef5bef8c751f 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -42,7 +42,8 @@ jobs: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --use_featurizers --use_dnnl --build_shared_lib --enable_onnx_tests --use_dml --use_cuda --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0' + arguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 16 2019" --msvc_toolset 14.16 --build_wheel --use_featurizers --use_dnnl --build_shared_lib --enable_onnx_tests --use_dml --use_winml --use_cuda --cuda_version=10.0 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.6.5.32\cuda" --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.18362.0' workingDirectory: '$(Build.BinariesDirectory)' - task: VSBuild@1 diff --git
a/winml/adapter/CpuOrtSessionBuilder.cpp b/winml/adapter/CpuOrtSessionBuilder.cpp deleted file mode 100644 index 72d09ff022941..0000000000000 --- a/winml/adapter/CpuOrtSessionBuilder.cpp +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "pch.h" - -// Needed to work around the fact that OnnxRuntime defines ERROR -#ifdef ERROR -#undef ERROR -#endif -#include "core/session/inference_session.h" -// Restore ERROR define -#define ERROR 0 - -#include "CpuOrtSessionBuilder.h" -#include "WinMLAdapter.h" -#include "WinMLAdapterErrors.h" - -// winml includes -#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" - -// ort includes -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/optimizer/conv_activation_fusion.h" -#include "core/optimizer/gemm_activation_fusion.h" -#include "core/session/abi_session_options_impl.h" - -using namespace Windows::AI::MachineLearning; - -namespace Windows::AI::MachineLearning::Adapter { - -CpuOrtSessionBuilder::CpuOrtSessionBuilder() { - -} - -HRESULT -CpuOrtSessionBuilder::CreateSessionOptions( - OrtSessionOptions** options) try { - RETURN_HR_IF_NULL(E_POINTER, options); - - Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options)); - Ort::SessionOptions session_options(*options); - - // set the graph optimization level to all (used to be called level 3) - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - - // Onnxruntime will use half the number of concurrent threads supported on the system - // by default. This causes MLAS to not exercise every logical core. - // We force the thread pool size to be maxxed out to ensure that WinML always - // runs the fastest. - session_options.SetIntraOpNumThreads(std::thread::hardware_concurrency()); - - // call release() so the underlying OrtSessionOptions object isn't freed - session_options.release(); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT -CpuOrtSessionBuilder::CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) try { - RETURN_HR_IF_NULL(E_POINTER, p_session); - RETURN_HR_IF_NULL(E_POINTER, pp_provider); - RETURN_HR_IF(E_POINTER, *pp_provider != nullptr); - - // Create the inference session - auto session = std::make_unique(options->value); - - // Create the cpu execution provider - onnxruntime::CPUExecutionProviderInfo xpInfo; -#ifndef _WIN64 - xpInfo.create_arena = false; -#endif - auto cpu_provider = std::make_unique(xpInfo); - - // Cache the provider's raw pointer - *pp_provider = cpu_provider.get(); - - // Register the cpu xp - ORT_THROW_IF_ERROR(session->RegisterExecutionProvider(std::move(cpu_provider))); - - // assign the session to the out parameter - auto sessionptr = wil::MakeOrThrow(session.release()); - RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(winmla::IInferenceSession), (void**)p_session)); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT -CpuOrtSessionBuilder::Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* /*p_provider*/ -) try { - ORT_THROW_IF_ERROR(p_session->get()->Initialize()); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -} // Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/adapter/CpuOrtSessionBuilder.h b/winml/adapter/CpuOrtSessionBuilder.h deleted file mode 100644 index 700129275f490..0000000000000 --- a/winml/adapter/CpuOrtSessionBuilder.h +++ /dev/null @@ -1,30 +0,0 @@ -// 
Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "WinMLAdapter.h" - -namespace Windows::AI::MachineLearning::Adapter { - -class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - winmla::IOrtSessionBuilder> { - - public: - CpuOrtSessionBuilder(); - - HRESULT STDMETHODCALLTYPE CreateSessionOptions( - OrtSessionOptions** options) override; - - HRESULT STDMETHODCALLTYPE CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) override; - - HRESULT STDMETHODCALLTYPE Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* p_provider) override; -}; - -} // namespace Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/adapter/CustomRegistryHelper.h b/winml/adapter/CustomRegistryHelper.h deleted file mode 100644 index de2987e676447..0000000000000 --- a/winml/adapter/CustomRegistryHelper.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#ifdef USE_DML -#include "core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h" - -namespace Windows::AI::MachineLearning::Adapter { - -inline std::list> -GetLotusCustomRegistries( - IMLOperatorRegistry* registry) { - if (registry != nullptr) { - // Down-cast to the concrete type. - // The only supported input is the AbiCustomRegistry type. - // Other implementations of IMLOperatorRegistry are forbidden. - auto abi_custom_registry = - static_cast(registry); - - // Get the ORT registry - return abi_custom_registry->GetRegistries(); - } - - return {}; -} - -} // namespace Windows::AI::MachineLearning::Adapter - -#endif USE_DML diff --git a/winml/adapter/DmlOrtSessionBuilder.cpp b/winml/adapter/DmlOrtSessionBuilder.cpp deleted file mode 100644 index 29da1a6332642..0000000000000 --- a/winml/adapter/DmlOrtSessionBuilder.cpp +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
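The contract in the deleted CustomRegistryHelper.h above is deliberately narrow: `GetLotusCustomRegistries` assumes any incoming `IMLOperatorRegistry` is the adapter's own `AbiCustomRegistry`, down-casts it, and hands back the underlying ORT registries. Its consumption pattern, mirroring the `RegisterCustomRegistry` helper further down (a sketch; `RegisterWinmlCustomOps` is a hypothetical name):

```cpp
#ifdef USE_DML
// Sketch: register every ORT CustomRegistry behind an IMLOperatorRegistry
// onto an inference session. GetLotusCustomRegistries returns an empty list
// for a null registry, so the loop degrades gracefully.
HRESULT RegisterWinmlCustomOps(onnxruntime::InferenceSession* session,
                               IMLOperatorRegistry* registry) {
  for (auto& custom_registry :
       Windows::AI::MachineLearning::Adapter::GetLotusCustomRegistries(registry)) {
    ORT_THROW_IF_ERROR(session->RegisterCustomRegistry(custom_registry));
  }
  return S_OK;
}
#endif  // USE_DML
```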
- -#include "pch.h" - -#ifdef USE_DML - -// Needed to work around the fact that OnnxRuntime defines ERROR -#ifdef ERROR -#undef ERROR -#endif -#include "core/session/inference_session.h" -// Restore ERROR define -#define ERROR 0 - -#include "DmlOrtSessionBuilder.h" -#include "WinMLAdapterErrors.h" - -// winml includes -#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" -#include "CustomRegistryHelper.h" -#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" -#include "LearningModelDevice.h" -#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" - -// ort includes -#include "core/framework/op_kernel.h" -#include "core/framework/op_node_proto_helper.h" -#include "core/framework/customRegistry.h" -#include "core/framework/data_transfer.h" -#include "core/session/abi_session_options_impl.h" - -using namespace Windows::AI::MachineLearning; - -namespace Windows::AI::MachineLearning::Adapter { - -DmlOrtSessionBuilder::DmlOrtSessionBuilder( - ID3D12Device* device, - ID3D12CommandQueue* queue) { - device_.copy_from(device); - queue_.copy_from(queue); -} - -HRESULT -DmlOrtSessionBuilder::CreateSessionOptions( - OrtSessionOptions** options) try { - RETURN_HR_IF_NULL(E_POINTER, options); - - Ort::ThrowOnError(Ort::GetApi().CreateSessionOptions(options)); - Ort::SessionOptions session_options(*options); - - // set the graph optimization level to all (used to be called level 3) - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - - // Disable the mem pattern session option for DML. It will cause problems with how memory is allocated. - session_options.DisableMemPattern(); - - // call release() so the underlying OrtSessionOptions object isn't freed - session_options.release(); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -static HRESULT -RegisterCustomRegistry( - onnxruntime::InferenceSession* p_session, - IMLOperatorRegistry* registry) { - if (registry != nullptr) { - RETURN_HR_IF_NULL(E_POINTER, p_session); - - auto custom_registries = GetLotusCustomRegistries(registry); - - // Register - for (auto& custom_registry : custom_registries) { - ORT_THROW_IF_ERROR(p_session->RegisterCustomRegistry(custom_registry)); - } - } - - return S_OK; -} - -Microsoft::WRL::ComPtr CreateDmlDevice(ID3D12Device* d3d12Device) { - // Dynamically load DML to avoid WinML taking a static dependency on DirectML.dll - wil::unique_hmodule dmlDll(LoadLibraryW(L"DirectML.dll")); - THROW_LAST_ERROR_IF(!dmlDll); - - auto dmlCreateDevice1Fn = reinterpret_cast( - GetProcAddress(dmlDll.get(), "DMLCreateDevice1")); - THROW_LAST_ERROR_IF(!dmlCreateDevice1Fn); - - DML_CREATE_DEVICE_FLAGS dmlFlags = DML_CREATE_DEVICE_FLAG_NONE; - - // Enable the DML debug layer in DEBUG builds, if the D3D12 debug layer is also enabled -#if _DEBUG - Microsoft::WRL::ComPtr d3d12DebugDevice; - if (SUCCEEDED(d3d12Device->QueryInterface(IID_PPV_ARGS(&d3d12DebugDevice)))) { - d3d12DebugDevice = nullptr; - dmlFlags |= DML_CREATE_DEVICE_FLAG_DEBUG; - } -#endif - - Microsoft::WRL::ComPtr dmlDevice; - THROW_IF_FAILED(dmlCreateDevice1Fn(d3d12Device, dmlFlags, DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dmlDevice))); - - // Keep DirectML.dll loaded by leaking the handle. This is equivalent behavior to if we delay-loaded the DLL. 
- dmlDll.release(); - - return dmlDevice; -} - -HRESULT DmlOrtSessionBuilder::CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) try { - RETURN_HR_IF_NULL(E_POINTER, p_session); - RETURN_HR_IF_NULL(E_POINTER, pp_provider); - RETURN_HR_IF(E_POINTER, *pp_provider != nullptr); - - auto p_d3d_device = device_.get(); - auto p_queue = queue_.get(); - - Microsoft::WRL::ComPtr dmlDevice = CreateDmlDevice(p_d3d_device); - - std::unique_ptr gpu_provider = Dml::CreateExecutionProvider(dmlDevice.Get(), p_queue); - auto session = std::make_unique(options->value); - - const onnxruntime::Env& env = onnxruntime::Env::Default(); - LUID temp_LUID = p_d3d_device->GetAdapterLuid(); - env.GetTelemetryProvider().LogExecutionProviderEvent(&temp_LUID); - // Cache the provider's raw pointer - *pp_provider = gpu_provider.get(); - - ORT_THROW_IF_ERROR(session->RegisterExecutionProvider(std::move(gpu_provider))); - - // assign the session to the out parameter - auto sessionptr = wil::MakeOrThrow(session.release()); - RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(winmla::IInferenceSession), (void**)p_session)); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT DmlOrtSessionBuilder::Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* p_provider) try { - RETURN_HR_IF_NULL(E_INVALIDARG, p_session); - RETURN_HR_IF_NULL(E_INVALIDARG, p_provider); - - // OnnxRuntime uses the default rounding mode when calling the session's allocator. - // During initialization, OnnxRuntime allocates weights, which are permanent across session - // lifetime and can be large, so shouldn't be rounded. - Dml::SetDefaultRoundingMode(p_provider, AllocatorRoundingMode::Disabled); - - ORT_THROW_IF_ERROR(p_session->get()->Initialize()); - - Dml::SetDefaultRoundingMode(p_provider, AllocatorRoundingMode::Enabled); - - // Flush the D3D12 work from the DML execution provider - Dml::FlushContext(p_provider); - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -} // namespace Windows::AI::MachineLearning::Adapter - -#endif USE_DML \ No newline at end of file diff --git a/winml/adapter/DmlOrtSessionBuilder.h b/winml/adapter/DmlOrtSessionBuilder.h deleted file mode 100644 index a02d1c21f8800..0000000000000 --- a/winml/adapter/DmlOrtSessionBuilder.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
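The rounding-mode sequence in `DmlOrtSessionBuilder::Initialize` above is the same policy that the new `DmlConfigureProviderFactoryDefaultRoundingMode` hook in dml_provider_factory.cc now exposes at the factory level. Condensed into one free-standing sketch (`InitializeWithDmlRounding` is a hypothetical name; every call it makes appears in this patch):

```cpp
// Weights allocated during Initialize() live for the whole session lifetime,
// so pool rounding would only pad them out; per-evaluation allocations are
// transient and do benefit. Hence: rounding off for init, back on afterwards.
HRESULT InitializeWithDmlRounding(winmla::IInferenceSession* session,
                                  onnxruntime::IExecutionProvider* provider) {
  Dml::SetDefaultRoundingMode(provider, AllocatorRoundingMode::Disabled);
  ORT_THROW_IF_ERROR(session->get()->Initialize());
  Dml::SetDefaultRoundingMode(provider, AllocatorRoundingMode::Enabled);
  Dml::FlushContext(provider);  // submit the D3D12 work queued during init
  return S_OK;
}
```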
- -#pragma once - -#include "WinMLAdapter.h" - -namespace Windows::AI::MachineLearning::Adapter { - -class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - winmla::IOrtSessionBuilder> { - - public: - DmlOrtSessionBuilder(ID3D12Device* device, ID3D12CommandQueue* queue); - - HRESULT STDMETHODCALLTYPE CreateSessionOptions( - OrtSessionOptions** options) override; - - HRESULT STDMETHODCALLTYPE CreateSession( - OrtSessionOptions* options, - winmla::IInferenceSession** p_session, - onnxruntime::IExecutionProvider** pp_provider) override; - - HRESULT STDMETHODCALLTYPE Initialize( - winmla::IInferenceSession* p_session, - onnxruntime::IExecutionProvider* p_provider) override; - - private: - winrt::com_ptr device_; - winrt::com_ptr queue_; -}; - -} // namespace Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/adapter/FeatureDescriptorFactory.h b/winml/adapter/FeatureDescriptorFactory.h deleted file mode 100644 index 497f92d9cbc8b..0000000000000 --- a/winml/adapter/FeatureDescriptorFactory.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once -#include "pch.h" - -namespace Windows::AI::MachineLearning { - -struct FeatureDescriptorFactory { - FeatureDescriptorFactory( - const std::unordered_map& model_metadata); - - wfc::IVector - CreateDescriptorsFromValueInfoProtos( - const std::vector& value_info_protos); - - private: - const std::unordered_map& metadata_; -}; - -} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/adapter/LotusEnvironment.cpp b/winml/adapter/LotusEnvironment.cpp deleted file mode 100644 index 30e3af20c46fc..0000000000000 --- a/winml/adapter/LotusEnvironment.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "pch.h" -#include "LotusEnvironment.h" -#include "core/platform/windows/TraceLoggingConfig.h" -#include - -bool Windows::AI::MachineLearning::CWinMLLogSink::debug_output_ = false; -void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl( - const onnxruntime::logging::Timestamp& timestamp, - const std::string& logger_id, - const onnxruntime::logging::Capture& message) { - // ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry. 
- switch (message.Severity()) { - case (onnxruntime::logging::Severity::kFATAL): //Telemetry - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str()), - TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); - break; - case (onnxruntime::logging::Severity::kERROR): //Telemetry - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_ERROR), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str()), - TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); - break; - case (onnxruntime::logging::Severity::kWARNING): - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_WARNING), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str())); - break; - case (onnxruntime::logging::Severity::kINFO): - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str())); - break; - case (onnxruntime::logging::Severity::kVERBOSE): - __fallthrough; //Default is Verbose too. 
- default: - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "WinMLLogSink", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(message.Category()), - TraceLoggingUInt32((UINT32)message.Severity()), - TraceLoggingString(message.Message().c_str()), - TraceLoggingString(message.Location().ToString(onnxruntime::CodeLocation::kFilenameAndPath).c_str())); - } - if (debug_output_) { - OutputDebugStringA(std::string(message.Message() + "\r\n").c_str()); - } -} - -void Windows::AI::MachineLearning::CWinMLLogSink::SendProfileEvent(onnxruntime::profiling::EventRecord& eventRecord) const { - if (eventRecord.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) { - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "OnnxRuntimeProfiling", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(onnxruntime::profiling::event_categor_names_[eventRecord.cat], "Category"), - TraceLoggingInt64(eventRecord.dur, "Duration (us)"), - TraceLoggingInt64(eventRecord.ts, "Time Stamp (us)"), - TraceLoggingString(eventRecord.name.c_str(), "Event Name"), - TraceLoggingInt32(eventRecord.pid, "Process ID"), - TraceLoggingInt32(eventRecord.tid, "Thread ID"), - TraceLoggingString(eventRecord.args["op_name"].c_str(), "Operator Name"), - TraceLoggingString(eventRecord.args["provider"].c_str(), "Execution Provider")); - } else { - TraceLoggingWrite( - winmla::winml_trace_logging_provider, - "OnnxRuntimeProfiling", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), - TraceLoggingString(onnxruntime::profiling::event_categor_names_[eventRecord.cat], "Category"), - TraceLoggingInt64(eventRecord.dur, "Duration (us)"), - TraceLoggingInt64(eventRecord.ts, "Time Stamp (us)"), - TraceLoggingString(eventRecord.name.c_str(), "Event Name"), - TraceLoggingInt32(eventRecord.pid, "Process ID"), - TraceLoggingInt32(eventRecord.tid, "Thread ID")); - } -} diff --git a/winml/adapter/LotusEnvironment.h b/winml/adapter/LotusEnvironment.h deleted file mode 100644 index 37bb8ad7ed584..0000000000000 --- a/winml/adapter/LotusEnvironment.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once -#include "core/common/logging/isink.h" -#include -#include -#include "WinMLAdapter.h" - -#pragma warning(push) -#pragma warning(disable : 4505) - -namespace Windows { -namespace AI { -namespace MachineLearning { -class CWinMLLogSink : public onnxruntime::logging::ISink { - public: - CWinMLLogSink() { - } - static void EnableDebugOutput() { - debug_output_ = true; - OutputDebugStringW(L"Windows.AI.MachineLearning: Debug Output Enabled \r\n"); - } - void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const; - void SendImpl(const onnxruntime::logging::Timestamp& timestamp, const std::string& logger_id, const onnxruntime::logging::Capture& message); - - private: - static bool debug_output_; -}; -// TODO: a bug in ORT requires a logging manager. 
This function registers a static singleton logger as "default" -inline onnxruntime::logging::LoggingManager& DefaultLoggingManager() { - // create a CLog based default logging manager - static std::string default_logger_id{"Default"}; - static onnxruntime::logging::LoggingManager default_logging_manager{ - std::unique_ptr{new CWinMLLogSink()}, - onnxruntime::logging::Severity::kVERBOSE, - false, - onnxruntime::logging::LoggingManager::InstanceType::Default, - &default_logger_id, - MAXINT32}; - - return default_logging_manager; -} - -class LotusEnvironment { - public: - LotusEnvironment() { - const HRESULT etw_status = TraceLoggingRegister(winmla::winml_trace_logging_provider); - if (FAILED(etw_status)) { - throw std::runtime_error("WinML TraceLogging registration failed. Logging will be broken: " + std::to_string(etw_status)); - } - - // TODO: Do we need to call this or just define the method? - default_logging_manager_ = &DefaultLoggingManager(); - - if (!onnxruntime::Environment::Create(lotus_environment_).IsOK()) { - throw winrt::hresult_error(E_FAIL); - } - - auto allocatorMap = onnxruntime::DeviceAllocatorRegistry::Instance().AllRegistrations(); - if (allocatorMap.find("Cpu") == allocatorMap.end()) { - onnxruntime::DeviceAllocatorRegistry::Instance().RegisterDeviceAllocator( - "Cpu", - [](int) { return std::make_unique(); }, - std::numeric_limits::max()); - } - } - - ~LotusEnvironment() { - TraceLoggingUnregister(winmla::winml_trace_logging_provider); - } - - const onnxruntime::logging::Logger* GetDefaultLogger() { - return &default_logging_manager_->DefaultLogger(); - } - - private: - std::unique_ptr lotus_environment_; - onnxruntime::logging::LoggingManager* default_logging_manager_; -}; - -namespace ExecutionProviders { -__declspec(selectany) const char* CPUExecutionProvider = "CPUExecutionProvider"; -} - -} // namespace MachineLearning -} // namespace AI -} // namespace Windows - -#pragma warning(pop) \ No newline at end of file diff --git a/winml/adapter/WinMLAdapter.cpp b/winml/adapter/WinMLAdapter.cpp deleted file mode 100644 index eef56c60799a8..0000000000000 --- a/winml/adapter/WinMLAdapter.cpp +++ /dev/null @@ -1,759 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
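A design note on `CWinMLLogSink::SendImpl` above: `TraceLoggingLevel` must be a compile-time constant, which is why the sink repeats a nearly identical `TraceLoggingWrite` per severity instead of computing the level once. The mapping it encodes, pulled out for readability (a hypothetical helper, usable only where a runtime level is acceptable):

```cpp
#include <windows.h>
#include <winmeta.h>  // WINEVENT_LEVEL_*
#include "core/common/logging/severity.h"

// kFATAL and kERROR events additionally carry MICROSOFT_KEYWORD_MEASURES,
// i.e. they go out as telemetry; the lower severities are trace-only.
UCHAR EtwLevelForSeverity(onnxruntime::logging::Severity severity) {
  switch (severity) {
    case onnxruntime::logging::Severity::kFATAL:   return WINEVENT_LEVEL_CRITICAL;
    case onnxruntime::logging::Severity::kERROR:   return WINEVENT_LEVEL_ERROR;
    case onnxruntime::logging::Severity::kWARNING: return WINEVENT_LEVEL_WARNING;
    case onnxruntime::logging::Severity::kINFO:    return WINEVENT_LEVEL_INFO;
    case onnxruntime::logging::Severity::kVERBOSE:
    default:                                       return WINEVENT_LEVEL_VERBOSE;
  }
}
```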
- -#include "pch.h" -#include "WinMLAdapter.h" -#include "WinMLAdapterErrors.h" -#include "CustomRegistryHelper.h" -#include "PheonixSingleton.h" -#include "LotusEnvironment.h" -#include "AbiCustomRegistryImpl.h" - -#ifdef USE_DML -#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" -#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" -#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" -#include "DmlOrtSessionBuilder.h" -#endif USE_DML - -#include "LearningModelDevice.h" -#include "TensorFeatureDescriptor.h" -#include "ImageFeatureDescriptor.h" -#include "api.image/inc/D3DDeviceCache.h" -#include "Common/inc/WinMLTelemetryHelper.h" - -#include "CpuOrtSessionBuilder.h" - -#include -#include - -#include "ZeroCopyInputStreamWrapper.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" - -#include "FeatureDescriptorFactory.h" -#include "core\framework\utils.h" -#include "core\framework\session_state.h" -#include "core/providers/winml/winml_provider_factory.h" - -using namespace winrt::Windows::AI::MachineLearning; - -namespace Windows::AI::MachineLearning::Adapter { - -// Define winml trace logging provider with WinML GUID -TRACELOGGING_DEFINE_PROVIDER( - winml_trace_logging_provider, - WINML_PROVIDER_DESC, - WINML_PROVIDER_GUID); - -// ORT intentionally requires callers derive from their session class to access -// the protected methods used below. -class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSession { - public: - onnxruntime::common::Status - Load(std::unique_ptr p_model_proto) { - return onnxruntime::InferenceSession::Load(std::move(p_model_proto)); - } - const onnxruntime::SessionState& GetSessionState() { - return *session_state_; - } -}; - -class ModelProto : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IModelProto> { - public: - ModelProto::ModelProto(onnx::ModelProto* model_proto) : model_proto_(model_proto) { - } - - onnx::ModelProto* STDMETHODCALLTYPE get() noexcept override { - return model_proto_.get(); - } - - onnx::ModelProto* STDMETHODCALLTYPE detach() noexcept override { - return model_proto_.release(); - } - - private: - std::unique_ptr model_proto_; -}; // class ModelProto - -class ModelInfo : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IModelInfo> { - private: - std::string author_; - std::string name_; - std::string domain_; - std::string description_; - int64_t version_; - std::unordered_map model_metadata_; - wfc::IVector input_features_; - wfc::IVector output_features_; - - public: - ModelInfo(const onnx::ModelProto* model_proto) { - Initialize(model_proto); - } - - const char* STDMETHODCALLTYPE author() noexcept override { - return author_.c_str(); - } - - const char* STDMETHODCALLTYPE name() noexcept override { - return name_.c_str(); - } - - const char* STDMETHODCALLTYPE domain() noexcept override { - return domain_.c_str(); - } - - const char* STDMETHODCALLTYPE description() noexcept override { - return description_.c_str(); - } - - int64_t STDMETHODCALLTYPE version() noexcept override { - return version_; - } - - HRESULT STDMETHODCALLTYPE GetModelMetadata( - ABI::Windows::Foundation::Collections::IMapView** metadata) override try { - *metadata = nullptr; - std::unordered_map map_copy; - for (auto& pair : model_metadata_) { - auto key = WinML::Strings::HStringFromUTF8(pair.first); - auto map_value = WinML::Strings::HStringFromUTF8(pair.second); - map_copy.emplace(std::move(key), 
std::move(map_value)); - } - auto out = winrt::single_threaded_map( - std::move(map_copy)); - - winrt::copy_to_abi(out.GetView(), *(void**)metadata); - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetInputFeatures( - ABI::Windows::Foundation::Collections::IVectorView** features) override try { - *features = nullptr; - winrt::copy_to_abi(input_features_.GetView(), *(void**)features); - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetOutputFeatures( - ABI::Windows::Foundation::Collections::IVectorView** features) override try { - *features = nullptr; - winrt::copy_to_abi(output_features_.GetView(), *(void**)features); - return S_OK; - } - WINMLA_CATCH_ALL_COM - - static std::vector - GetAllNodeOutputs(const onnx::ModelProto& model_proto) { - std::vector nodes_outputs; - auto& graph = model_proto.graph(); - auto& nodes = graph.node(); - for (auto& node : nodes) { - for (auto& node_output : node.output()) { - nodes_outputs.push_back(node_output.c_str()); - } - } - return nodes_outputs; - } - - static std::vector - GetInitializers(const onnx::ModelProto& model_proto) { - std::vector initializers; - auto& graph = model_proto.graph(); - auto& graph_initializers = graph.initializer(); - for (auto& initializer : graph_initializers) { - initializers.push_back(initializer.name().c_str()); - } - return initializers; - } - - static std::vector - GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { - auto initializers = GetInitializers(model_proto); - - std::vector inputs_without_initializers; - auto& graph = model_proto.graph(); - auto& inputs = graph.input(); - for (auto& input : inputs) { - if (input.has_name() && input.has_type()) { - auto found_it = std::find_if( - std::begin(initializers), - std::end(initializers), - [&](auto& initializer) { - return std::strcmp(initializer, input.name().c_str()) == 0; - }); - - auto is_initializer = found_it != std::end(initializers); - if (!is_initializer) { - inputs_without_initializers.push_back(&input); - } - } - } - return inputs_without_initializers; - } - - static std::vector GetOutputs(const onnx::ModelProto& model_proto) { - std::vector outputs_with_name; - auto& graph = model_proto.graph(); - auto& outputs = graph.output(); - for (auto& output : outputs) { - if (output.has_name() && output.has_type()) { - outputs_with_name.push_back(&output); - } - } - return outputs_with_name; - } - - private: - void Initialize(const onnx::ModelProto* model_proto) { - // metadata - for (auto& prop : model_proto->metadata_props()) { - model_metadata_[prop.key()] = prop.value(); - } - - WinML::FeatureDescriptorFactory builder(model_metadata_); - - // Create inputs - auto inputs = GetInputsWithoutInitializers(*model_proto); - input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); - - // Create outputs - auto outputs = GetOutputs(*model_proto); - output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); - - // author - auto has_producer_name = model_proto->has_producer_name(); - author_ = has_producer_name - ? model_proto->producer_name() - : ""; - - // domain - auto has_domain = model_proto->has_domain(); - domain_ = has_domain - ? model_proto->domain() - : ""; - - // name - auto has_graph = model_proto->has_graph(); - auto graph_has_name = model_proto->graph().has_name(); - auto is_name_available = has_graph && graph_has_name; - name_ = is_name_available - ? 
model_proto->graph().name() - : ""; - - // description - auto has_description = model_proto->has_doc_string(); - description_ = has_description - ? model_proto->doc_string() - : ""; - - // version - auto has_version = model_proto->has_model_version(); - version_ = has_version - ? model_proto->model_version() - : 0; - } -}; // class ModelInfo - -class WinMLAdapter : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, - IWinMLAdapter> { - private: - // TODO: Making this static is only temporary. A fix addressing the resulting the memory leaks is needed. - static std::shared_ptr lotus_environment_; - - public: - WinMLAdapter() { - if (lotus_environment_ == nullptr) { - lotus_environment_ = PheonixSingleton(); - } - } - // factory methods for creating an ort model from a path - HRESULT STDMETHODCALLTYPE CreateModelProto( - const char* path, - IModelProto** model_proto) override try { - int file_descriptor; - _set_errno(0); // clear errno - _sopen_s( - &file_descriptor, - path, - O_RDONLY | _O_SEQUENTIAL | _O_BINARY, - _SH_DENYWR, - _S_IREAD | _S_IWRITE); - - errno_t err = 0; - _get_errno(&err); - THROW_HR_IF_MSG( - __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND), - err == ENOENT, - "File not found: %s", - path); - - THROW_HR_IF_MSG( - E_FAIL, - 0 > file_descriptor, - "Failed"); //errno - - auto stream = google::protobuf::io::FileInputStream(file_descriptor); - stream.SetCloseOnDelete(true); - - auto model_proto_inner = new onnx::ModelProto(); - THROW_HR_IF_MSG( - E_INVALIDARG, - model_proto_inner->ParseFromZeroCopyStream(&stream) == false, - "The stream failed to parse."); - - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - WINMLA_CATCH_ALL_COM - - // factory methods for creating an ort model from a stream - HRESULT STDMETHODCALLTYPE CreateModelProto( - ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, - IModelProto** model_proto) override try { - ZeroCopyInputStreamWrapper wrapper(stream_reference); - - auto model_proto_inner = std::make_unique(); - THROW_HR_IF_MSG( - E_INVALIDARG, - model_proto_inner->ParseFromZeroCopyStream(&wrapper) == false, - "The stream failed to parse."); - - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner.release()); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - WINMLA_CATCH_ALL_COM - - // factory methods for creating an ort model from a model_proto - HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto* model_proto_in, IModelProto** model_proto) override try { - auto model_proto_inner = new onnx::ModelProto(*model_proto_in->get()); - auto model_proto_outer = wil::MakeOrThrow(model_proto_inner); - return model_proto_outer.CopyTo(__uuidof(IModelProto), reinterpret_cast(model_proto)); - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto* model_proto, IModelInfo** model_info) override try { - auto model_info_outer = wil::MakeOrThrow(model_proto->get()); - return model_info_outer.CopyTo(__uuidof(IModelInfo), reinterpret_cast(model_info)); - } - WINMLA_CATCH_ALL_COM - - void STDMETHODCALLTYPE EnableDebugOutput() override try { - WinML::CWinMLLogSink::EnableDebugOutput(); - } - WINMLA_CATCH_ALL_DONOTHING - - static bool IsFeatureDescriptorFp16( - winml::ILearningModelFeatureDescriptor descriptor) { - if (auto imageFeatureDescriptor = descriptor.try_as()) { - return TensorKind::Float16 == imageFeatureDescriptor.TensorKind(); - } 
- - if (auto tensorFeatureDescriptor = descriptor.try_as()) { - return TensorKind::Float16 == tensorFeatureDescriptor.TensorKind(); - } - - return false; - } - - HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility( - winml::LearningModel const& model, - IModelProto* p_model_proto, - bool is_float16_supported) override try { - if (!is_float16_supported) { - auto& graph = p_model_proto->get()->graph(); - - // The model will not contain fp16 operations if: - // 1. The model has no fp16 inputs - // 2. The model has no fp16 initializers - // 3. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator - // 4. The model does not have any fp16 outputs - - // 1. Ensure that The model has no fp16 inputs - for (auto descriptor : model.InputFeatures()) { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - IsFeatureDescriptorFp16(descriptor), - "The model contains a 16-bit input (%ls), but the current device does not support 16-bit float.", - descriptor.Name().c_str()); - } - - // 2. Ensure that the model has no fp16 initializers - for (int i = 0; i < graph.node_size(); i++) { - auto node = graph.node(i); - if (node.op_type() == "Cast" && node.domain().empty()) { - for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { - auto attribute = node.attribute(attribIndex); - if (attribute.name() == "to") { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, - "The model contains a 16-bit float Cast Op (%s), but the current device does not support 16-bit float.", - node.name().c_str()); - } - } - } - } - - // 3. Ensure that the model does not create any fp16 intermediary - // tensors via the Cast (to float16) operator - for (int i = 0; i < graph.initializer_size(); i++) { - auto initializer = graph.initializer(i); - - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16, - "The model contains a 16-bit float initializer (%s), but the current device does not support 16-bit float.", - initializer.name().c_str()); - } - - // 4. Ensure that the model does not have any fp16 outputs - for (auto descriptor : model.OutputFeatures()) { - THROW_HR_IF_MSG( - DXGI_ERROR_UNSUPPORTED, - IsFeatureDescriptorFp16(descriptor), - "The model contains a 16-bit output (%ls), but the current device does not support 16-bit float.", - descriptor.Name().c_str()); - } - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider* provider, void* allocation) override try { -#ifdef USE_DML - auto d3dResource = - Dml::GetD3D12ResourceFromAllocation( - provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), - allocation); - return d3dResource; -#else - return nullptr; -#endif USE_DML - } catch (...) 
{ - return nullptr; - } - - static onnxruntime::MLDataType GetType(winml::TensorKind kind) { - switch (kind) { - case winml::TensorKind::Float: - return onnxruntime::DataTypeImpl::GetType(); - case winml::TensorKind::Float16: - return onnxruntime::DataTypeImpl::GetType(); - }; - return nullptr; - } - - // factory method for creating an ortsessionbuilder from a device - HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder( - ID3D12Device* device, - ID3D12CommandQueue* queue, - IOrtSessionBuilder** session_builder) override try { - if (device == nullptr) { - auto builder = wil::MakeOrThrow(); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); - } -#ifdef USE_DML - else { - auto builder = wil::MakeOrThrow(device, queue); - return builder.CopyTo(__uuidof(IOrtSessionBuilder), reinterpret_cast(session_builder)); - } -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override try { - *key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - auto type = ort_value->Type(); - if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue* ort_value, ONNXTensorElementDataType* key_type, ONNXTensorElementDataType* value_type) override try { - *key_type = *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - auto type = ort_value->Type(); - if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (type == onnxruntime::DataTypeImpl::GetType()) { - *key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - *value_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) override try { -#ifdef USE_DML - auto impl = wil::MakeOrThrow(); - *registry = impl.Detach(); - return S_OK; -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative* 
operator_provider_native, IMLOperatorRegistry** registry) override try { -#ifdef USE_DML - // Retrieve the "operator abi" registry. - winrt::com_ptr operator_registry; - THROW_IF_FAILED(operator_provider_native->GetRegistry(operator_registry.put())); - *registry = operator_registry.detach(); - return S_OK; -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) override try { -#ifdef USE_DML - return Dml::CreateGPUAllocationFromD3DResource(pResource); -#else - return nullptr; -#endif USE_DML - } catch (...) { - return nullptr; - } - - void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) override try { -#ifdef USE_DML - Dml::FreeGPUAllocation(ptr); -#endif USE_DML - } - WINMLA_CATCH_ALL_DONOTHING - - HRESULT STDMETHODCALLTYPE CopyTensor( - onnxruntime::IExecutionProvider* provider, - OrtValue* src, - OrtValue* dst) override try { -#ifdef USE_DML - ORT_THROW_IF_ERROR(Dml::CopyTensor(provider, *(src->GetMutable()), *(dst->GetMutable()))); - return S_OK; -#else - return E_NOTIMPL; -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - // Override select shape inference functions which are incomplete in ONNX with versions that are complete, - // and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being - // deferred until first evaluation. It also prevents a situation where inference functions in externally - // registered schema are reachable only after upstream schema have been revised in a later OS release, - // which would be a compatibility risk. - HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override try { -#ifdef USE_DML - static std::once_flag schema_override_once_flag; - std::call_once(schema_override_once_flag, []() { - SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); - }); - return S_OK; -#else - return S_OK; // needs to return S_OK otherwise everything breaks because this gets called from the learningmodel constructor -#endif USE_DML - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo( - onnxruntime::IExecutionProvider* provider, - OrtMemoryInfo** memory_info) override try { - auto allocator = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); - - const auto& info = allocator->Info(); - *memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); - if (*memory_info == nullptr) { - return E_OUTOFMEMORY; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue* ort_value, OrtMemoryInfo** memory_info) override try { - const auto& tensor = ort_value->Get(); - auto info = tensor.Location(); - *memory_info = new OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); - if (*memory_info == nullptr) { - return E_OUTOFMEMORY; - } - return S_OK; - } - WINMLA_CATCH_ALL_COM - - struct AllocatorWrapper : public OrtAllocator { - public: - AllocatorWrapper(onnxruntime::AllocatorPtr impl) : impl_(impl) { - version = ORT_API_VERSION; - Alloc = AllocImpl; - Free = FreeImpl; - Info = InfoImpl; - } - - static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { - return static_cast(this_)->impl_->Alloc(size); - } - static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { - return static_cast(this_)->impl_->Free(p); - } - static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { - return &(static_cast(this_)->impl_->Info()); 
- } - - private: - onnxruntime::AllocatorPtr impl_; - }; - - HRESULT STDMETHODCALLTYPE GetProviderAllocator( - onnxruntime::IExecutionProvider* provider, - OrtAllocator** allocator) override try { - auto allocator_ptr = provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); - *allocator = new (std::nothrow) AllocatorWrapper(allocator_ptr); - if (*allocator == nullptr) { - return E_OUTOFMEMORY; - } - - return S_OK; - } - WINMLA_CATCH_ALL_COM - - HRESULT STDMETHODCALLTYPE FreeProviderAllocator( - OrtAllocator* allocator) override try { - delete static_cast(allocator); - return S_OK; - } - WINMLA_CATCH_ALL_COM -}; // namespace Windows::AI::MachineLearning::Adapter -std::shared_ptr WinMLAdapter::lotus_environment_ = nullptr; - -extern "C" HRESULT STDMETHODCALLTYPE OrtGetWinMLAdapter(IWinMLAdapter** adapter) try { - // make an adapter instance - Microsoft::WRL::ComPtr adapterptr = wil::MakeOrThrow(); - return adapterptr.CopyTo(__uuidof(IWinMLAdapter), reinterpret_cast(adapter)); -} -WINMLA_CATCH_ALL_COM - -// InferenceSession -// ================ - -InferenceSession::InferenceSession(onnxruntime::InferenceSession* session) : session_(session) { -} - -void STDMETHODCALLTYPE InferenceSession::RegisterGraphTransformers() try { -#ifdef USE_DML - // Bug 22973884 : Fix issues with BatchNorm + Add and BatchNorm + Mul handling implicit inputs, and move from Winml to ORT - GraphTransformerHelpers::RegisterGraphTransformers(session_.get()); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -HRESULT STDMETHODCALLTYPE InferenceSession::StartProfiling() try { - this->session_->StartProfiling(PheonixSingleton()->GetDefaultLogger()); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT STDMETHODCALLTYPE InferenceSession::EndProfiling() try { - this->session_->EndProfiling(); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT STDMETHODCALLTYPE -InferenceSession::LoadModel( - IModelProto* model_proto) try { - auto session_protected_load_accessor = - static_cast(session_.get()); - // session's like to have their very own copy of the model_proto, use detach() - std::unique_ptr model_proto_ptr(model_proto->detach()); - ORT_THROW_IF_ERROR(session_protected_load_accessor->Load(std::move(model_proto_ptr))); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -HRESULT STDMETHODCALLTYPE -InferenceSession::RegisterCustomRegistry( - IMLOperatorRegistry* registry) try { - RETURN_HR_IF(S_OK, registry == nullptr); - -#ifdef USE_DML - auto custom_registries = GetLotusCustomRegistries(registry); - - // Register - for (auto& custom_registry : custom_registries) { - ORT_THROW_IF_ERROR(session_->RegisterCustomRegistry(custom_registry)); - } -#endif USE_DML - - return S_OK; -} -WINMLA_CATCH_ALL_COM - -void STDMETHODCALLTYPE InferenceSession::FlushContext(onnxruntime::IExecutionProvider* dml_provider) try { -#ifdef USE_DML - Dml::FlushContext(dml_provider); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -void STDMETHODCALLTYPE InferenceSession::TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) try { -#ifdef USE_DML - Dml::TrimUploadHeap(dml_provider); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -void STDMETHODCALLTYPE InferenceSession::ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) try { -#ifdef USE_DML - Dml::ReleaseCompletedReferences(dml_provider); -#endif USE_DML -} -WINMLA_CATCH_ALL_DONOTHING - -HRESULT STDMETHODCALLTYPE InferenceSession::CopyOneInputAcrossDevices( - const char* input_name, - const OrtValue* orig_mlvalue, - OrtValue** new_mlvalue) try { - auto 
session_protected_load_accessor = - static_cast(session_.get()); - const onnxruntime::SessionState& sessionState = session_protected_load_accessor->GetSessionState(); - auto temp_mlvalue = std::make_unique(); - ORT_THROW_IF_ERROR(onnxruntime::utils::CopyOneInputAcrossDevices(sessionState, input_name, *orig_mlvalue, *temp_mlvalue.get())); - *new_mlvalue = temp_mlvalue.release(); - return S_OK; -} -WINMLA_CATCH_ALL_COM - -} // namespace Windows::AI::MachineLearning::Adapter \ No newline at end of file diff --git a/winml/adapter/WinMLAdapter.h b/winml/adapter/WinMLAdapter.h deleted file mode 100644 index 6062e1e62f45a..0000000000000 --- a/winml/adapter/WinMLAdapter.h +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "core/session/onnxruntime_c_api.h" - -namespace Windows::AI::MachineLearning::Adapter { -TRACELOGGING_DECLARE_PROVIDER(winml_trace_logging_provider); - -MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{ - // model metadata - virtual const char* STDMETHODCALLTYPE author() = 0; - virtual const char* STDMETHODCALLTYPE name() = 0; - virtual const char* STDMETHODCALLTYPE domain() = 0; - virtual const char* STDMETHODCALLTYPE description() = 0; - virtual int64_t STDMETHODCALLTYPE version() = 0; - virtual HRESULT STDMETHODCALLTYPE GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView ** metadata) = 0; - virtual HRESULT STDMETHODCALLTYPE GetInputFeatures(ABI::Windows::Foundation::Collections::IVectorView * *features) = 0; - virtual HRESULT STDMETHODCALLTYPE GetOutputFeatures(ABI::Windows::Foundation::Collections::IVectorView * *features) = 0; -}; - -MIDL_INTERFACE("a848faf6-5a2e-4a7f-b622-cc036f71e28a") IModelProto : IUnknown{ - // this returns a weak ref - virtual onnx::ModelProto* STDMETHODCALLTYPE get() = 0; - // this returns the ownership without touching the reference and forgets about the object - virtual onnx::ModelProto* STDMETHODCALLTYPE detach() = 0; -}; - -MIDL_INTERFACE("6ec766ef-6365-42bf-b64f-ae85c015adb8") IInferenceSession : IUnknown { - virtual onnxruntime::InferenceSession* STDMETHODCALLTYPE get() = 0; - // the below returns a weak ref , DO NOT RELEASE IT - virtual HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession ** out) = 0; - virtual void STDMETHODCALLTYPE RegisterGraphTransformers() = 0; - virtual HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry * registry) = 0; - virtual HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model_proto) = 0; - virtual HRESULT STDMETHODCALLTYPE StartProfiling() = 0; - virtual HRESULT STDMETHODCALLTYPE EndProfiling() = 0; - virtual void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider * dml_provider) = 0; - virtual void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) = 0; - virtual void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) = 0; - virtual HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(const char* input_name, - const OrtValue* orig_mlvalue, OrtValue** new_mlvalue) = 0; -}; - -// The IOrtSessionBuilder offers an abstraction over the creation of -// InferenceSession, that enables the creation of the session based on a device (CPU/DML). 
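To make the builder contract below concrete, here is a minimal sketch of the intended calling sequence. This is hypothetical caller code, not the WinML implementation: the adapter, device, and queue variables are assumed to already exist, and error handling beyond THROW_IF_FAILED is elided.

    // Hypothetical caller-side flow for the builder interface declared below.
    winrt::com_ptr<winmla::IOrtSessionBuilder> builder;
    THROW_IF_FAILED(adapter->CreateOrtSessionBuilder(device.get(), queue.get(), builder.put()));

    // 1) Build device-appropriate session options (CPU or DML).
    OrtSessionOptions* options = nullptr;
    THROW_IF_FAILED(builder->CreateSessionOptions(&options));

    // 2) Create the session and its execution provider.
    winrt::com_ptr<winmla::IInferenceSession> session;
    onnxruntime::IExecutionProvider* provider = nullptr;
    THROW_IF_FAILED(builder->CreateSession(options, session.put(), &provider));

    // 3) Finalize device-specific initialization.
    THROW_IF_FAILED(builder->Initialize(session.get(), provider));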
-MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") IOrtSessionBuilder : IUnknown { - - virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions( - OrtSessionOptions ** options) = 0; - - virtual HRESULT STDMETHODCALLTYPE CreateSession( - OrtSessionOptions * options, - IInferenceSession** session, - onnxruntime::IExecutionProvider** provider) = 0; - - virtual HRESULT STDMETHODCALLTYPE Initialize( - IInferenceSession* session, - onnxruntime::IExecutionProvider* provider) = 0; -}; - - -MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown { - - virtual void STDMETHODCALLTYPE EnableDebugOutput() = 0; - - virtual HRESULT STDMETHODCALLTYPE EnsureModelDeviceCompatibility( - winml::LearningModel const& model, - IModelProto* p_model_proto, - bool is_float16_supported) = 0; - - // factory method for creating an ortsessionbuilder from a device - virtual HRESULT STDMETHODCALLTYPE CreateOrtSessionBuilder( - ID3D12Device* device, - ID3D12CommandQueue* queue, - IOrtSessionBuilder** session_builder) = 0; - - // factory methods for creating model protos - virtual HRESULT STDMETHODCALLTYPE CreateModelProto(const char* path, IModelProto** model_proto) = 0; - virtual HRESULT STDMETHODCALLTYPE CreateModelProto(ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, IModelProto** model_proto) = 0; - virtual HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) = 0; - virtual HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) = 0; - - // Data types - - // custom ops - virtual HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) = 0; - virtual HRESULT STDMETHODCALLTYPE GetOperatorRegistry(ILearningModelOperatorProviderNative * operator_provider_native, IMLOperatorRegistry * *registry) = 0; - - // dml ep hooks - virtual void* STDMETHODCALLTYPE CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) = 0; - virtual void STDMETHODCALLTYPE FreeGPUAllocation(void* ptr) = 0; - virtual HRESULT STDMETHODCALLTYPE CopyTensor(onnxruntime::IExecutionProvider* provider, OrtValue* src, OrtValue* dst) = 0; - // note: this returns a weak ref - virtual ID3D12Resource* STDMETHODCALLTYPE GetD3D12ResourceFromAllocation(onnxruntime::IExecutionProvider * provider, void* allocation) = 0; - - // schema overrides (dml does this for us) - virtual HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() = 0; - - // proposed adapter. 
uses the cross plat ABI currencies - virtual HRESULT STDMETHODCALLTYPE GetProviderMemoryInfo(onnxruntime::IExecutionProvider * provider, OrtMemoryInfo** memory_info) = 0; - virtual HRESULT STDMETHODCALLTYPE GetProviderAllocator(onnxruntime::IExecutionProvider * provider, OrtAllocator** allocator) = 0; - virtual HRESULT STDMETHODCALLTYPE FreeProviderAllocator(OrtAllocator* allocator) = 0; - virtual HRESULT STDMETHODCALLTYPE GetValueMemoryInfo(const OrtValue * value, OrtMemoryInfo** memory_info) = 0; - virtual HRESULT STDMETHODCALLTYPE GetMapType(const OrtValue * ort_value, ONNXTensorElementDataType * key_type, ONNXTensorElementDataType * value_type) = 0; - virtual HRESULT STDMETHODCALLTYPE GetVectorMapType(const OrtValue * ort_value, ONNXTensorElementDataType * key_type, ONNXTensorElementDataType * value_type) = 0; - //virtual HRESULT STDMETHODCALLTYPE CreateTensorFromMap(IInspectable * map, OrtValue * *ort_value) = 0; - //virtual HRESULT STDMETHODCALLTYPE CreateTensorFromSequence(IInspectable * sequence, OrtValue * *ort_value) = 0; -}; - -class InferenceSession : public Microsoft::WRL::RuntimeClass < - Microsoft::WRL::RuntimeClassFlags, - IInferenceSession> { - -public: - - InferenceSession(onnxruntime::InferenceSession * session); - - onnxruntime::InferenceSession* STDMETHODCALLTYPE get() noexcept override { - return session_.get(); - } - - HRESULT STDMETHODCALLTYPE GetOrtSession(OrtSession ** out) noexcept override { - // (OrtSession *) are really (InferenceSession *) as well - *out = reinterpret_cast(session_.get()); - return S_OK; - } - - void STDMETHODCALLTYPE RegisterGraphTransformers() override; - HRESULT STDMETHODCALLTYPE RegisterCustomRegistry(IMLOperatorRegistry* registry) override; - HRESULT STDMETHODCALLTYPE LoadModel(IModelProto* model_proto) override; - HRESULT STDMETHODCALLTYPE StartProfiling() override; - HRESULT STDMETHODCALLTYPE EndProfiling() override; - void STDMETHODCALLTYPE FlushContext(onnxruntime::IExecutionProvider* dml_provider) override; - void STDMETHODCALLTYPE TrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider) override; - void STDMETHODCALLTYPE ReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider) override; - HRESULT STDMETHODCALLTYPE CopyOneInputAcrossDevices(const char* input_name, - const OrtValue* orig_mlvalue, OrtValue** new_mlvalue) override; - - -private: - std::shared_ptr session_; -}; - -} // namespace Windows::AI::MachineLearning::Adapter - -namespace Ort { -// Ort::Allocator is not in the C ABI yet so it will have to be in the WinMLAdapter for now. -// This struct was copied using the Base struct from onnxruntime_cxx_api.h for reference -// Ort::Allocator struct is used as a smart pointer to OrtAllocator. 
-struct Allocator { - Allocator() { - m_ort_allocator = nullptr; - m_adapter = nullptr; - } - Allocator(winmla::IWinMLAdapter* adapter, OrtAllocator* ort_allocator) : - m_adapter(adapter), m_ort_allocator(ort_allocator) {} - - ~Allocator() { - if (m_adapter != nullptr && m_ort_allocator != nullptr) { - m_adapter->FreeProviderAllocator(m_ort_allocator); - } - } - - operator OrtAllocator*() { return m_ort_allocator; } - operator const OrtAllocator*() const { return m_ort_allocator; } - - OrtAllocator* release() { - OrtAllocator* p = m_ort_allocator; - m_ort_allocator = nullptr; - m_adapter = nullptr; - return p; - } - - OrtAllocator** put() noexcept { - assert(m_ort_allocator == nullptr); - return &m_ort_allocator; - } - - Allocator(const Allocator&) = delete; - Allocator& operator=(const Allocator&) = delete; - Allocator(Allocator&& v) noexcept : - m_adapter{v.m_adapter}, m_ort_allocator{v.m_ort_allocator} { - v.m_adapter = nullptr; - v.m_ort_allocator = nullptr; - } - void operator=(Allocator&& v) noexcept { - if (m_ort_allocator != nullptr && m_adapter != nullptr) { - m_adapter->FreeProviderAllocator(m_ort_allocator); - } - m_adapter = v.m_adapter; - m_ort_allocator = v.m_ort_allocator; - v.m_adapter = nullptr; - v.m_ort_allocator = nullptr; - } - - private: - winmla::IWinMLAdapter* m_adapter; - OrtAllocator* m_ort_allocator; -}; -} // namespace Ort \ No newline at end of file diff --git a/winml/adapter/WinMLAdapterErrors.h b/winml/adapter/WinMLAdapterErrors.h deleted file mode 100644 index 5513842761422..0000000000000 --- a/winml/adapter/WinMLAdapterErrors.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include "core/common/status.h" - -inline __declspec(noinline) winrt::hresult_error _winmla_to_hresult() noexcept { - try { - throw; - } catch (winrt::hresult_error const& e) { - return e; - } catch (wil::ResultException const& e) { - return winrt::hresult_error(e.GetErrorCode(), winrt::to_hstring(e.what())); - } catch (std::bad_alloc const&) { - return winrt::hresult_error(E_OUTOFMEMORY); - } catch (std::out_of_range const& e) { - return winrt::hresult_out_of_bounds(winrt::to_hstring(e.what())); - } catch (std::invalid_argument const& e) { - return winrt::hresult_invalid_argument(winrt::to_hstring(e.what())); - } catch (onnxruntime::OnnxRuntimeException const& e) { - StatusCode eStatusCode = static_cast(e.GetStatus().Code()); - return winrt::hresult_error(StatusCodeToHRESULT(eStatusCode), winrt::to_hstring(e.GetStatus().ErrorMessage())); - } catch (std::exception const& e) { - return winrt::hresult_error(E_FAIL, winrt::to_hstring(e.what())); - } catch (...) { - return winrt::hresult_error(E_FAIL); - } -} - -#define WINMLA_CATCH_ALL \ - catch (...) { \ - throw _winmla_to_hresult(); \ - } - -#define WINMLA_CATCH_ALL_COM \ - catch (...) { \ - return _winmla_to_hresult().to_abi(); \ - } - -#define WINMLA_CATCH_ALL_DONOTHING \ - catch (...) { \ - return; \ - } \ No newline at end of file diff --git a/winml/adapter/ZeroCopyInputStreamWrapper.cpp b/winml/adapter/ZeroCopyInputStreamWrapper.cpp deleted file mode 100644 index 1b53326719030..0000000000000 --- a/winml/adapter/ZeroCopyInputStreamWrapper.cpp +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
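The WINMLA_CATCH_ALL_* macros above close the try blocks used throughout the adapter, funneling every exception type through _winmla_to_hresult. A small sketch of the pattern (hypothetical method; SomeLotusCall is a placeholder, not from the patch):

    HRESULT STDMETHODCALLTYPE DoWork() override try {
      // Any winrt::hresult_error, wil::ResultException, OnnxRuntimeException,
      // std::exception, or unknown exception thrown here is translated into
      // an HRESULT by the catch clause the macro expands to.
      ORT_THROW_IF_ERROR(SomeLotusCall());
      return S_OK;
    }
    WINMLA_CATCH_ALL_COM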
- -#include "pch.h" - -#include "ZeroCopyInputStreamWrapper.h" - -#include "winrt/Windows.Foundation.h" - -using namespace Windows::AI::MachineLearning; - -// ZeroCopyInputStreamWrapper -ZeroCopyInputStreamWrapper::ZeroCopyInputStreamWrapper( - ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream) { - winrt::copy_from_abi(stream_, (void*)stream); -} - -bool ZeroCopyInputStreamWrapper::Next( - const void** data, - int* size) { - if (finished_reading_) { - return false; - } - - auto content = stream_.OpenReadAsync().get(); - - wss::Buffer buffer(static_cast(content.Size())); - auto result = content.ReadAsync( - buffer, - buffer.Capacity(), - wss::InputStreamOptions::None) - .get(); - - bytes_ = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>(); -#ifdef LAYERING_DONE - WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes_, "Model stream is invalid."); - WINML_THROW_IF_FAILED_MSG( - bytes_->Buffer(reinterpret_cast(const_cast(data))), - "Failed to acquire buffer from model stream."); -#else - bytes_->Buffer(reinterpret_cast(const_cast(data))); -#endif - - *size = static_cast(content.Size()); - finished_reading_ = true; - return true; -} - -// BackUp is used when parsing encounters an error and needs to move -// back to the beginning of the erroneous chunk. We don't support random access, -// so we don't have a pointer to move back, but this can also happen for -// decrypted strings since they can have extra memory at the end that -// isn't valid. We don't want to parse non-model related data so we -// don't support this. I'd like to thrown an error here, but protobuf would -// eat that error and terminate the app. So instead we do nothing and handle -// this in LoadFromStream when the protobuf parsing returns false. -void ZeroCopyInputStreamWrapper::BackUp(int count) { - // purposely do nothing. -} - -// the following methods are required by the interface, -// but they aren't actually used by ModelProto parse code, -bool ZeroCopyInputStreamWrapper::Skip( - int count) { -#ifdef LAYERING_DONE - WINML_THROW_HR(E_NOTIMPL); -#endif - return false; -} - -__int64 -ZeroCopyInputStreamWrapper::ByteCount() const { -#ifdef LAYERING_DONE - WINML_THROW_HR(E_NOTIMPL); -#endif - return 0; -} diff --git a/winml/adapter/ZeroCopyInputStreamWrapper.h b/winml/adapter/ZeroCopyInputStreamWrapper.h deleted file mode 100644 index 8938468317606..0000000000000 --- a/winml/adapter/ZeroCopyInputStreamWrapper.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "winrt/Windows.Storage.Streams.h" -#include - -namespace Windows::AI::MachineLearning { -// _ZeroCopyInputStreamWrapper is a helper class that allows a ZeroCopyInputStream, -// which is a protobuf type, to read from an IRandomAccessStreamReference, which is -// a winrt type. 
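For context, ZeroCopyInputStream is protobuf's streaming-input interface, so the wrapper declared below can be handed straight to ModelProto parsing. A hedged sketch of the consuming side (paraphrased, not the exact WinML implementation; stream_reference is assumed to be the caller's IRandomAccessStreamReference*):

    ZeroCopyInputStreamWrapper wrapper(stream_reference);
    auto model_proto = std::make_unique<onnx::ModelProto>();
    // ParseFromZeroCopyStream drives Next()/BackUp(); a false return covers the
    // BackUp() case described above, where trailing bytes are not a valid model.
    if (!model_proto->ParseFromZeroCopyStream(&wrapper)) {
      winrt::throw_hresult(E_INVALIDARG);
    }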
-class ZeroCopyInputStreamWrapper : public google::protobuf::io::ZeroCopyInputStream { - public: - ZeroCopyInputStreamWrapper() = delete; - - ZeroCopyInputStreamWrapper( - ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream); - - // ModelProto load only uses "Next" method - bool - Next( - const void** data, - int* size); - - void - BackUp( - int count); - - bool - Skip( - int count); - - __int64 - ByteCount() const; - - private: - wss::IRandomAccessStreamReference stream_; - bool finished_reading_ = false; - winrt::com_ptr<::Windows::Storage::Streams::IBufferByteAccess> bytes_; -}; - -} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/adapter/AbiCustomRegistryImpl.cpp b/winml/adapter/abi_custom_registry_impl.cpp similarity index 98% rename from winml/adapter/AbiCustomRegistryImpl.cpp rename to winml/adapter/abi_custom_registry_impl.cpp index 7242ca121ffe5..00b20cba1b95f 100644 --- a/winml/adapter/AbiCustomRegistryImpl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -5,7 +5,7 @@ #ifdef USE_DML -#include "AbiCustomRegistryImpl.h" +#include "abi_custom_registry_impl.h" namespace Windows::AI::MachineLearning::Adapter { diff --git a/winml/adapter/AbiCustomRegistryImpl.h b/winml/adapter/abi_custom_registry_impl.h similarity index 93% rename from winml/adapter/AbiCustomRegistryImpl.h rename to winml/adapter/abi_custom_registry_impl.h index a07f51cacd067..77b8cba2897d4 100644 --- a/winml/adapter/AbiCustomRegistryImpl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -6,7 +6,7 @@ #ifdef USE_DML #include "core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h" -namespace Windows::AI::MachineLearning::Adapter{ +namespace Windows::AI::MachineLearning::Adapter { // An implementation of AbiCustomRegistry that emits telemetry events when operator kernels or schemas are registered. class AbiCustomRegistryImpl : public AbiCustomRegistry { @@ -38,5 +38,5 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { _In_opt_ IMLOperatorShapeInferrer* shape_inferrer) const noexcept override; }; -} // namespace winrt::Windows::AI::MachineLearning::Adapter +} // namespace Windows::AI::MachineLearning::Adapter #endif USE_DML diff --git a/winml/adapter/winml_adapter_apis.h b/winml/adapter/winml_adapter_apis.h new file mode 100644 index 0000000000000..1c33d5393ef47 --- /dev/null +++ b/winml/adapter/winml_adapter_apis.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#include "winml_adapter_c_api.h" + +namespace Windows { +namespace AI { +namespace MachineLearning { +namespace Adapter { + +ORT_API(void, ReleaseModel, OrtModel*); +ORT_API(void, ReleaseExecutionProvider, OrtExecutionProvider*); + +ORT_API_STATUS(OverrideSchema); + +// OrtEnv methods +ORT_API_STATUS(EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + +// OrtModel methods +ORT_API_STATUS(CreateModelFromPath, _In_ const char* model_path, _In_ size_t size, _Outptr_ OrtModel** out); +ORT_API_STATUS(CreateModelFromData, _In_ void* data, _In_ size_t size, _Outptr_ OrtModel** out); +ORT_API_STATUS(CloneModel, _In_ const OrtModel* in, _Outptr_ OrtModel** out); +ORT_API_STATUS(ModelGetAuthor, _In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len); +ORT_API_STATUS(ModelGetName, _In_ const OrtModel* model, _Out_ const char** const name, _Out_ size_t* len); +ORT_API_STATUS(ModelGetDomain, _In_ const OrtModel* model, _Out_ const char** const domain, _Out_ size_t* len); +ORT_API_STATUS(ModelGetDescription, _In_ const OrtModel* model, _Out_ const char** const description, _Out_ size_t* len); +ORT_API_STATUS(ModelGetVersion, _In_ const OrtModel* model, _Out_ int64_t* version); +ORT_API_STATUS(ModelGetInputCount, _In_ const OrtModel* model, _Out_ size_t* count); +ORT_API_STATUS(ModelGetOutputCount, _In_ const OrtModel* model, _Out_ size_t* count); +ORT_API_STATUS(ModelGetInputName, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_name, _Out_ size_t* count); +ORT_API_STATUS(ModelGetOutputName, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_name, _Out_ size_t* count); +ORT_API_STATUS(ModelGetInputDescription, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_description, _Out_ size_t* count); +ORT_API_STATUS(ModelGetOutputDescription, _In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_description, _Out_ size_t* count); +ORT_API_STATUS(ModelGetInputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info); +ORT_API_STATUS(ModelGetOutputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info); +ORT_API_STATUS(ModelGetMetadataCount, _In_ const OrtModel* model, _Out_ size_t* count); +ORT_API_STATUS(ModelGetMetadata, _In_ const OrtModel* model, _Out_ size_t count, _Out_ const char** const key, _Out_ size_t* key_len, _Out_ const char** const value, _Out_ size_t* value_len); +ORT_API_STATUS(ModelEnsureNoFloat16, _In_ const OrtModel* model); + +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, _In_ ID3D12Device* d3d_device, _In_ ID3D12CommandQueue* cmd_queue); + +// OrtSession methods +ORT_API_STATUS(CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session); + +//Do not release provider... 
as there is no release method available +ORT_API_STATUS(SessionGetExecutionProvider, _In_ OrtSession* session, size_t index, _Out_ OrtExecutionProvider** provider); +ORT_API_STATUS(SessionInitialize, _In_ OrtSession* session); +ORT_API_STATUS(SessionLoadAndPurloinModel, _In_ OrtSession* session, _In_ OrtModel* model); + +ORT_API_STATUS(SessionStartProfiling, _In_ OrtEnv* env, _In_ OrtSession* session); +ORT_API_STATUS(SessionEndProfiling, _In_ OrtSession* session); +ORT_API_STATUS(SessionRegisterGraphTransformers, _In_ OrtSession* session); +ORT_API_STATUS(SessionRegisterCustomRegistry, _In_ OrtSession* session, _In_ IMLOperatorRegistry* registry); +ORT_API_STATUS(SessionCopyOneInputAcrossDevices, _In_ OrtSession* session, _In_ const char* const input_name, _In_ OrtValue* orig_value, _Outptr_ OrtValue** new_value); + +// Dml methods (TODO need to figure out how these need to move to session somehow...) +ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled); +ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider); +ORT_API_STATUS(DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider); +ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider); +ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource); +ORT_API_STATUS(DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource); +ORT_API_STATUS(DmlFreeGPUAllocation, _In_ void* ptr); + +// note: this returns a weak ref + +ORT_API_STATUS(GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info); +ORT_API_STATUS(GetProviderAllocator, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator); +ORT_API_STATUS(FreeProviderAllocator, _In_ OrtAllocator* allocator); +ORT_API_STATUS(GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info); + +// ExecutionProvider Methods +ORT_API_STATUS(ExecutionProviderSync, _In_ OrtExecutionProvider* provider); +ORT_API_STATUS(DmlCopyTensor, _In_ OrtExecutionProvider* provider, _In_ OrtValue* src, _In_ OrtValue* dst); +ORT_API_STATUS(CreateCustomRegistry, _Out_ IMLOperatorRegistry** registry); + +ORT_API_STATUS(ValueGetDeviceId, _In_ OrtValue* ort_value, _Out_ int16_t* device_id); +ORT_API_STATUS(SessionGetInputRequiredDeviceId, _In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id); + +} // namespace Adapter +} // namespace MachineLearning +} // namespace AI +} // namespace Windows \ No newline at end of file diff --git a/winml/adapter/winml_adapter_c_api.cpp b/winml/adapter/winml_adapter_c_api.cpp new file mode 100644 index 0000000000000..3ab5645893c8e --- /dev/null +++ b/winml/adapter/winml_adapter_c_api.cpp @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "winml_adapter_apis.h" +#include "core/session/ort_apis.h" + +#include +#include +#include + +const OrtApi* GetVersion1Api(); + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +static constexpr WinmlAdapterApi winml_adapter_api_1 = { + // Schema override + &winmla::OverrideSchema, + + // OrtEnv methods + &winmla::EnvConfigureCustomLoggerAndProfiler, + + // OrtTypeInfo Casting methods + &OrtApis::GetDenotationFromTypeInfo, + &OrtApis::CastTypeInfoToMapTypeInfo, + &OrtApis::CastTypeInfoToSequenceTypeInfo, + + // OrtMapTypeInfo Accessors + &OrtApis::GetMapKeyType, + &OrtApis::GetMapValueType, + + // OrtSequenceTypeInfo Accessors + &OrtApis::GetSequenceElementType, + + // OrtModel methods + &winmla::CreateModelFromPath, + &winmla::CreateModelFromData, + &winmla::CloneModel, + &winmla::ModelGetAuthor, + &winmla::ModelGetName, + &winmla::ModelGetDomain, + &winmla::ModelGetDescription, + &winmla::ModelGetVersion, + &winmla::ModelGetInputCount, + &winmla::ModelGetOutputCount, + &winmla::ModelGetInputName, + &winmla::ModelGetOutputName, + &winmla::ModelGetInputDescription, + &winmla::ModelGetOutputDescription, + &winmla::ModelGetInputTypeInfo, + &winmla::ModelGetOutputTypeInfo, + &winmla::ModelGetMetadataCount, + &winmla::ModelGetMetadata, + &winmla::ModelEnsureNoFloat16, + + // OrtSessionOptions methods + &OrtSessionOptionsAppendExecutionProvider_CPU, + &winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, + + // OrtSession methods + &winmla::CreateSessionWithoutModel, + &winmla::SessionGetExecutionProvider, + &winmla::SessionInitialize, + &winmla::SessionRegisterGraphTransformers, + &winmla::SessionRegisterCustomRegistry, + &winmla::SessionLoadAndPurloinModel, + &winmla::SessionStartProfiling, + &winmla::SessionEndProfiling, + &winmla::SessionCopyOneInputAcrossDevices, + + // Dml methods (TODO need to figure out how these need to move to session somehow...) + &winmla::DmlExecutionProviderSetDefaultRoundingMode, + &winmla::DmlExecutionProviderFlushContext, + &winmla::DmlExecutionProviderTrimUploadHeap, + &winmla::DmlExecutionProviderReleaseCompletedReferences, + &winmla::DmlCreateGPUAllocationFromD3DResource, + &winmla::DmlFreeGPUAllocation, + &winmla::DmlGetD3D12ResourceFromAllocation, + &winmla::DmlCopyTensor, + + &winmla::GetProviderMemoryInfo, + &winmla::GetProviderAllocator, + &winmla::FreeProviderAllocator, + &winmla::GetValueMemoryInfo, + + &winmla::ExecutionProviderSync, + + &winmla::CreateCustomRegistry, + + &winmla::ValueGetDeviceId, + &winmla::SessionGetInputRequiredDeviceId, + + // Release + &winmla::ReleaseModel, + &OrtApis::ReleaseMapTypeInfo, + &OrtApis::ReleaseSequenceTypeInfo}; + +const WinmlAdapterApi* ORT_API_CALL OrtGetWinMLAdapter(const OrtApi* ort_api) NO_EXCEPTION { + if (GetVersion1Api() == ort_api) { + return &winml_adapter_api_1; + } + + return nullptr; +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_c_api.h b/winml/adapter/winml_adapter_c_api.h new file mode 100644 index 0000000000000..7f2e17259e0be --- /dev/null +++ b/winml/adapter/winml_adapter_c_api.h @@ -0,0 +1,469 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/session/onnxruntime_c_api.h" + +ORT_RUNTIME_CLASS(Model); +ORT_RUNTIME_CLASS(ExecutionProvider); + +struct WinmlAdapterApi; +typedef struct WinmlAdapterApi WinmlAdapterApi; + +struct ID3D12Resource; +struct ID3D12Device; +struct ID3D12CommandQueue; +struct IMLOperatorRegistry; + +// TODO: Must match onnxruntime::profiling::EventRecord +enum OrtProfilerEventCategory { + SESSION_EVENT = 0, + NODE_EVENT, + EVENT_CATEGORY_MAX +}; + +struct OrtProfilerEventRecord { + OrtProfilerEventCategory category_; + const char* category_name_; + int64_t duration_; + int64_t time_span_; + const char* event_name_; + int32_t process_id_; + int32_t thread_id_; + const char* op_name_; + const char* execution_provider_; +}; + +typedef void(ORT_API_CALL* OrtProfilingFunction)(const OrtProfilerEventRecord* event_record); + +struct WinmlAdapterApi { + /** + * OverrideSchema + * This api is used to override schema inference functions for a variety of ops across opsets. + * This exists because certain ops were failing to infer schemas and caused performance + * issues for DML as it was forced to create resources during evaluation. + * This can be removed when schema inference functions have been updated. + */ + OrtStatus*(ORT_API_CALL* OverrideSchema)() NO_EXCEPTION; + + /** + * EnvConfigureCustomLoggerAndProfiler + * This api is used to add a custom logger and profiler to the ORT environment. + * This exists because existing methods on the c-abi to create the environment only support a custom logger. + * Since WinML hooks the profiler events, we expose the profiler and an associated profiling function. + */ + OrtStatus*(ORT_API_CALL* EnvConfigureCustomLoggerAndProfiler)(_In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out)NO_EXCEPTION; + + /** + * GetDenotationFromTypeInfo + * This api augments OrtTypeInfo to return denotations on the type. + * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. + */ + OrtStatus*(ORT_API_CALL* GetDenotationFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len)NO_EXCEPTION; + + // OrtTypeInfo Casting methods + + /** + * CastTypeInfoToMapTypeInfo + * This api augments OrtTypeInfo to return an OrtMapTypeInfo when the type is a map. + * The OrtMapTypeInfo has additional information about the map's key type and value type. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ + OrtStatus*(ORT_API_CALL* CastTypeInfoToMapTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtMapTypeInfo** out)NO_EXCEPTION; + + /** + * CastTypeInfoToSequenceTypeInfo + * This api augments OrtTypeInfo to return an OrtSequenceTypeInfo when the type is a sequence. + * The OrtSequenceTypeInfo has additional information about the sequence's element type. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ + OrtStatus*(ORT_API_CALL* CastTypeInfoToSequenceTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtSequenceTypeInfo** out)NO_EXCEPTION; + + // OrtMapTypeInfo Accessors + + /** + * GetMapKeyType + * This api gets the key type of a map. Key types are restricted to being scalar types and use ONNXTensorElementDataType. + * This is used by WinML to support model reflection APIs.
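+ *
+ * A sketch of the reflection pattern these accessors enable (hypothetical caller
+ * code; winmla_api and type_info are assumed to exist, status checks elided):
+ *
+ *   const OrtMapTypeInfo* map_info = nullptr;
+ *   winmla_api->CastTypeInfoToMapTypeInfo(type_info, &map_info);  // weak ref, don't free
+ *   ONNXTensorElementDataType key_type;
+ *   winmla_api->GetMapKeyType(map_info, &key_type);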
+ */ + OrtStatus*(ORT_API_CALL* GetMapKeyType)(_In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION; + + /** + * GetMapValueType + * This api gets the value type of a map. + */ + OrtStatus*(ORT_API_CALL* GetMapValueType)(_In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + // OrtSequenceTypeInfo Accessors + + /** + * GetSequenceElementType + * This api gets the element type of a sequence. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* GetSequenceElementType)(_In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + // OrtModel methods + + /** + * CreateModelFromPath + * This api creates an OrtModel based on a specified model path. + * There is no inferencing or evaluation setup performed. Only ONNX load is done to reflect on the model's inputs/outputs and other properties. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* CreateModelFromPath)(_In_ const char* model_path, _In_ size_t size, _Outptr_ OrtModel** out)NO_EXCEPTION; + + /** + * CreateModelFromData + * This api creates an OrtModel from a buffer. + * There is no inferencing or evaluation setup performed. Only ONNX load is done to reflect on the model's inputs/outputs and other properties. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* CreateModelFromData)(_In_ void* data, _In_ size_t size, _Outptr_ OrtModel** out)NO_EXCEPTION; + + /** + * CloneModel + * This api copies the OrtModel along with its internal proto buffer and cached metadata. + * The OrtSession type expects to own the model proto buffer. + * WinML uses this to yield copies of the model proto held by OrtModel to OrtSession. + */ + OrtStatus*(ORT_API_CALL* CloneModel)(_In_ const OrtModel* in, _Outptr_ OrtModel** out)NO_EXCEPTION; + + /** + * ModelGetAuthor + * This api gets the model author from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetAuthor)(_In_ const OrtModel* model, _Out_ const char** const author, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetName + * This api gets the model name from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetName)(_In_ const OrtModel* model, _Out_ const char** const name, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetDomain + * This api gets the model domain from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetDomain)(_In_ const OrtModel* model, _Out_ const char** const domain, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetDescription + * This api gets the model description from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetDescription)(_In_ const OrtModel* model, _Out_ const char** const description, _Out_ size_t* len)NO_EXCEPTION; + + /** + * ModelGetVersion + * This api gets the model version from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetVersion)(_In_ const OrtModel* model, _Out_ int64_t* version)NO_EXCEPTION; + + /** + * ModelGetInputCount + * This api gets the number of inputs from the OrtModel. It closely matches the API of a similar name for retrieving model metadata from OrtSession.
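+ *
+ * A sketch of the enumeration pattern (hypothetical caller code, status checks elided):
+ *
+ *   size_t input_count = 0;
+ *   winmla_api->ModelGetInputCount(model, &input_count);
+ *   for (size_t i = 0; i < input_count; ++i) {
+ *     const char* name = nullptr;
+ *     size_t name_len = 0;
+ *     winmla_api->ModelGetInputName(model, i, &name, &name_len);
+ *   }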
+ * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetOutputCount + * This api gets the number of outputs from the OrtModel. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetOutputCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetInputName + * This api gets the input name from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputName)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_name, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetOutputName + * This api gets the output name from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetOutputName)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_name, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetInputDescription + * This api gets the input description from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputDescription)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** input_description, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetOutputDescription + * This api gets the output description from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetOutputDescription)(_In_ const OrtModel* model, _In_ size_t index, _Out_ const char** output_description, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetInputTypeInfo + * This api gets the input OrtTypeInfo from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetInputTypeInfo)(_In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + /** + * ModelGetOutputTypeInfo + * This api gets the output OrtTypeInfo from the OrtModel given an index. It closely matches the API of a similar name for retrieving model metadata from OrtSession. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetOutputTypeInfo)(_In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + + /** + * ModelGetMetadataCount + * This api gets the number of metadata entries from the OrtModel. + * This is used by WinML to support model reflection APIs. + */ + OrtStatus*(ORT_API_CALL* ModelGetMetadataCount)(_In_ const OrtModel* model, _Out_ size_t* count)NO_EXCEPTION; + + /** + * ModelGetMetadata + * This api gets the model metadata from the OrtModel. 
+ * This is used by WinML to deduce whether model input and output formats are supported by the WinML tensorization code paths. + */ + OrtStatus*(ORT_API_CALL* ModelGetMetadata)(_In_ const OrtModel* model, _Out_ size_t count, _Out_ const char** const key, _Out_ size_t* key_len, _Out_ const char** const value, _Out_ size_t* value_len)NO_EXCEPTION; + + /** + * ModelEnsureNoFloat16 + * This api checks whether the model requires float16 support. + * This is used by WinML to fail gracefully when float16 support is not available on the device. + * + * Can this API be moved into the EP during session initialization? Currently we do an early fp16 check to avoid initialization when it is not supported. + */ + OrtStatus*(ORT_API_CALL* ModelEnsureNoFloat16)(_In_ const OrtModel* model)NO_EXCEPTION; + + // OrtSessionOptions methods + + /** + * OrtSessionOptionsAppendExecutionProvider_CPU + * This api is used to add the CPU EP to OrtSessionOptions so that WinML GPU sessions are configured with CPU fallback. + */ + OrtStatus*(ORT_API_CALL* OrtSessionOptionsAppendExecutionProvider_CPU)(_In_ OrtSessionOptions* options, int use_arena)NO_EXCEPTION; + + /** + * OrtSessionOptionsAppendExecutionProvider_DML + * This api is used to add the DML EP to OrtSessionOptions. + */ + OrtStatus*(ORT_API_CALL* OrtSessionOptionsAppendExecutionProvider_DML)(_In_ OrtSessionOptions* options, ID3D12Device* device, ID3D12CommandQueue* queue)NO_EXCEPTION; + + // OrtSession methods + + /** + * CreateSessionWithoutModel + * This api is used to create a Session that is completely uninitialized. While there are other Session creation APIs in the + * c-abi, WinML uses this so that it can perform optimizations prior to loading the model, and initializing. + * Moreover, WinML needs a new api to support the OrtModel type, and prevent parsing model protobufs again on session creation. + */ + OrtStatus*(ORT_API_CALL* CreateSessionWithoutModel)(_In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session)NO_EXCEPTION; + + /** + * SessionGetExecutionProvider + * This api is used to get a handle to an OrtExecutionProvider. + * Currently WinML uses this to talk directly to the DML EP and configure settings on it. + */ + OrtStatus*(ORT_API_CALL* SessionGetExecutionProvider)(_In_ OrtSession* session, _In_ size_t index, _Out_ OrtExecutionProvider** provider)NO_EXCEPTION; + + /** + * SessionInitialize + * This api is used to initialize an OrtSession. This is one component of creating a usable OrtSession, and is a part of CreateSession in the c-abi. + * Currently WinML uses this to finalize session creation, after configuring a variety of properties on the OrtSession. + */ + OrtStatus*(ORT_API_CALL* SessionInitialize)(_In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionRegisterGraphTransformers + * This api is used to enable DML specific graph transformations on an OrtSession. + * + * Ideally these transformations should be configured by the contract between the runtime and the EP and not overridden by WinML. + */ + OrtStatus*(ORT_API_CALL* SessionRegisterGraphTransformers)(_In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionRegisterCustomRegistry + * This api is used to support custom operators as they were shipped in WinML RS5. + */ + OrtStatus*(ORT_API_CALL* SessionRegisterCustomRegistry)(_In_ OrtSession* session, _In_ IMLOperatorRegistry* registry)NO_EXCEPTION; + + /** + * SessionLoadAndPurloinModel + * This api is used to load an OrtModel into an OrtSession.
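+ *
+ * A sketch of the overall bring-up sequence these session methods describe
+ * (hypothetical caller code, status checks elided):
+ *
+ *   OrtSession* session = nullptr;
+ *   winmla_api->CreateSessionWithoutModel(env, options, &session);
+ *   winmla_api->SessionRegisterGraphTransformers(session);
+ *   winmla_api->SessionLoadAndPurloinModel(session, model);  // session takes ownership of model
+ *   winmla_api->SessionInitialize(session);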
+ * + * Don't free the 'model' argument, as this API takes ownership of it and releases the OrtModel internally. + */ + OrtStatus*(ORT_API_CALL* SessionLoadAndPurloinModel)(_In_ OrtSession* session, _In_ OrtModel* model)NO_EXCEPTION; + + /** + * SessionStartProfiling + * This api is used to start profiling an OrtSession. The existing mechanism only allows configuring profiling at session creation. + * + * WinML uses this to toggle profiling on and off based on whether telemetry providers are being listened to. + */ + OrtStatus*(ORT_API_CALL* SessionStartProfiling)(_In_ OrtEnv* env, _In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionEndProfiling + * This api is used to end profiling an OrtSession. The existing mechanism only allows configuring profiling at session creation. + * + * WinML uses this to toggle profiling on and off based on whether telemetry providers are being listened to. + */ + OrtStatus*(ORT_API_CALL* SessionEndProfiling)(_In_ OrtSession* session)NO_EXCEPTION; + + /** + * SessionCopyOneInputAcrossDevices + * This api is used to copy and create an OrtValue input to prepare the input on the correct device. + * + * WinML uses this to copy GPU OrtValues to the CPU and vice-versa. + */ + OrtStatus*(ORT_API_CALL* SessionCopyOneInputAcrossDevices)(_In_ OrtSession* session, _In_ const char* const input_name, _In_ OrtValue* orig_value, _Outptr_ OrtValue** new_value)NO_EXCEPTION; + + // Dml methods (TODO need to figure out how these need to move to session somehow...) + + /** + * DmlExecutionProviderSetDefaultRoundingMode + * This api is used to configure the DML EP to turn on/off rounding. + * + * WinML uses this to disable rounding during session initialization and then enables it again post initialization. + */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderSetDefaultRoundingMode)(_In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled)NO_EXCEPTION; + + /** + * DmlExecutionProviderFlushContext + * This api is used to flush the DML EP. + * + * WinML communicates directly with DML to perform this as an optimization. + */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderFlushContext)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION; + + /** + * DmlExecutionProviderTrimUploadHeap + * This api is used to trim the upload heap in the DML EP. + * + * WinML communicates directly with DML to perform this as an optimization. + */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderTrimUploadHeap)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION; + + /** + * DmlExecutionProviderReleaseCompletedReferences + * This api is used to release completed references after the first run of the DML EP. + * + * WinML communicates directly with DML to perform this as an optimization. + */ + OrtStatus*(ORT_API_CALL* DmlExecutionProviderReleaseCompletedReferences)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION; + + /** + * DmlCreateGPUAllocationFromD3DResource + * This api is used to create a DML EP input based on a user-specified d3d12 resource. + * + * WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs. + */ + OrtStatus*(ORT_API_CALL* DmlCreateGPUAllocationFromD3DResource)(_In_ ID3D12Resource* pResource, _Out_ void** dml_resource)NO_EXCEPTION; + + /** + * DmlFreeGPUAllocation + * This api is used to free the DML EP input created by DmlCreateGPUAllocationFromD3DResource. + * + * WinML uses this as part of its Tensor apis to allow callers to specify their own D3D12 resources as inputs/outputs.
+ */ + OrtStatus*(ORT_API_CALL* DmlFreeGPUAllocation)(_In_ void* ptr)NO_EXCEPTION; + + /** + * DmlGetD3D12ResourceFromAllocation + * This api is used to get the D3D12 resource when an OrtValue has been allocated by the DML EP and accessed via GetMutableTensorData. + * + * WinML uses this in the image feature path to get the d3d resource and perform tensorization of inputs directly into the allocated d3d12 resource. + */ + OrtStatus*(ORT_API_CALL* DmlGetD3D12ResourceFromAllocation)(_In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource)NO_EXCEPTION; + + /** + * DmlCopyTensor + * This api is used to copy a tensor allocated by the DML EP Allocator to the CPU. + * + * WinML uses this when graphs are evaluated with DML, and their outputs remain on the GPU but need to be copied back to the CPU. + */ + OrtStatus*(ORT_API_CALL* DmlCopyTensor)(_In_ OrtExecutionProvider* provider, _In_ OrtValue* src, _In_ OrtValue* dst)NO_EXCEPTION; + + /** + * GetProviderMemoryInfo + * This api gets the memory info object associated with an EP. + * + * WinML uses this to manage caller-specified D3D12 inputs/outputs. It uses the memory info here to call DmlCreateGPUAllocationFromD3DResource. + */ + OrtStatus*(ORT_API_CALL* GetProviderMemoryInfo)(_In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info)NO_EXCEPTION; + + /** + * GetProviderAllocator + * This api gets the associated allocator used by a provider. + * + * WinML uses this to create tensors, and needs to hold onto the allocator for the duration of the associated value's lifetime. + */ + OrtStatus*(ORT_API_CALL* GetProviderAllocator)(_In_ OrtExecutionProvider* provider, OrtAllocator** allocator)NO_EXCEPTION; + + /** + * FreeProviderAllocator + * This api frees an allocator. + * + * WinML uses this to free the associated allocator for an OrtValue when creating tensors. + * Internally this derefs a shared_ptr. + */ + OrtStatus*(ORT_API_CALL* FreeProviderAllocator)(_In_ OrtAllocator* allocator)NO_EXCEPTION; + + /** + * GetValueMemoryInfo + * This api gets the memory info of an OrtValue. + * + * WinML uses this to determine if an OrtValue is allocated on the CPU or elsewhere. + */ + OrtStatus*(ORT_API_CALL* GetValueMemoryInfo)(const OrtValue* value, OrtMemoryInfo** memory_info)NO_EXCEPTION; + + /** + * ExecutionProviderSync + * This api syncs the EP. + * + * WinML uses this to sync EP inputs/outputs directly. + */ + OrtStatus*(ORT_API_CALL* ExecutionProviderSync)(_In_ OrtExecutionProvider* provider)NO_EXCEPTION; + + /** + * CreateCustomRegistry + * This api creates a custom registry that callers can populate with custom ops. + * + * WinML uses this to support custom ops. + */ + OrtStatus*(ORT_API_CALL* CreateCustomRegistry)(_Out_ IMLOperatorRegistry** registry)NO_EXCEPTION; + + /** + * ValueGetDeviceId + * This api returns the device id of the OrtValue. + * + * WinML uses this to determine if an OrtValue is created on the needed device. + */ + OrtStatus*(ORT_API_CALL* ValueGetDeviceId)(_In_ OrtValue* ort_value, _Out_ int16_t* device_id)NO_EXCEPTION; + + /** + * SessionGetInputRequiredDeviceId + * This api returns the required device id for a model input. + * + * WinML uses this to determine if an OrtValue is created on the needed device.
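+ *
+ * Together with ValueGetDeviceId this enables a device-mismatch check
+ * (hypothetical caller code, status checks elided):
+ *
+ *   int16_t value_device = 0;
+ *   int16_t required_device = 0;
+ *   winmla_api->ValueGetDeviceId(input_value, &value_device);
+ *   winmla_api->SessionGetInputRequiredDeviceId(session, input_name, &required_device);
+ *   if (value_device != required_device) {
+ *     OrtValue* moved = nullptr;
+ *     winmla_api->SessionCopyOneInputAcrossDevices(session, input_name, input_value, &moved);
+ *   }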
+ */ + OrtStatus*(ORT_API_CALL* SessionGetInputRequiredDeviceId)(_In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id)NO_EXCEPTION; + + ORT_CLASS_RELEASE(Model); + ORT_CLASS_RELEASE(MapTypeInfo); + ORT_CLASS_RELEASE(SequenceTypeInfo); +}; diff --git a/winml/adapter/winml_adapter_dml.cpp b/winml/adapter/winml_adapter_dml.cpp new file mode 100644 index 0000000000000..ddbd03475bc9f --- /dev/null +++ b/winml/adapter/winml_adapter_dml.cpp @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" + +#ifdef USE_DML +#include "core/session/abi_session_options_impl.h" +#include "core/providers/dml/dml_provider_factory.h" +#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +#endif // USE_DML + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +#ifdef USE_DML +Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) { + // Dynamically load DML to avoid WinML taking a static dependency on DirectML.dll + wil::unique_hmodule dmlDll(LoadLibraryW(L"DirectML.dll")); + THROW_LAST_ERROR_IF(!dmlDll); + + auto dmlCreateDevice1Fn = reinterpret_cast<decltype(&DMLCreateDevice1)>( + GetProcAddress(dmlDll.get(), "DMLCreateDevice1")); + THROW_LAST_ERROR_IF(!dmlCreateDevice1Fn); + + DML_CREATE_DEVICE_FLAGS dmlFlags = DML_CREATE_DEVICE_FLAG_NONE; + + // Enable the DML debug layer in DEBUG builds, if the D3D12 debug layer is also enabled +#if _DEBUG + Microsoft::WRL::ComPtr<ID3D12DebugDevice> d3d12DebugDevice; + if (SUCCEEDED(d3d12Device->QueryInterface(IID_PPV_ARGS(&d3d12DebugDevice)))) { + d3d12DebugDevice = nullptr; + dmlFlags |= DML_CREATE_DEVICE_FLAG_DEBUG; + } +#endif // _DEBUG + + Microsoft::WRL::ComPtr<IDMLDevice> dmlDevice; + THROW_IF_FAILED(dmlCreateDevice1Fn(d3d12Device, dmlFlags, DML_FEATURE_LEVEL_2_0, IID_PPV_ARGS(&dmlDevice))); + + // Keep DirectML.dll loaded by leaking the handle. This is equivalent behavior to if we delay-loaded the DLL. + dmlDll.release(); + + return dmlDevice; +} + +namespace onnxruntime { +void DmlConfigureProviderFactoryDefaultRoundingMode(onnxruntime::IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode); +} + +#endif // USE_DML + +ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, + ID3D12Device* d3d_device, ID3D12CommandQueue* queue) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_device = CreateDmlDevice(d3d_device); + if (auto status = OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue)) { + return status; + } + auto factory = options->provider_factories.back().get(); + + // OnnxRuntime uses the default rounding mode when calling the session's allocator. + // During initialization, OnnxRuntime allocates weights, which are permanent across session + // lifetime and can be large, so shouldn't be rounded. + // So we create the provider with rounding disabled, and expect the caller to enable it after.
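+  // The caller-side counterpart is expected to look roughly like this
+  // (hypothetical sketch, status checks elided):
+  //   winmla_api->CreateSessionWithoutModel(env, options, &session);  // rounding still disabled
+  //   winmla_api->SessionInitialize(session);                         // weights allocated unrounded
+  //   winmla_api->SessionGetExecutionProvider(session, 0, &provider);
+  //   winmla_api->DmlExecutionProviderSetDefaultRoundingMode(provider, /*is_enabled*/ true);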
onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::SetDefaultRoundingMode(dml_provider_internal, is_enabled ? AllocatorRoundingMode::Enabled : AllocatorRoundingMode::Disabled); +#endif + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::FlushContext(dml_provider_internal); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::TrimUploadHeap(dml_provider_internal); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + Dml::ReleaseCompletedReferences(dml_provider_internal); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource) { + API_IMPL_BEGIN +#ifdef USE_DML + *dml_resource = Dml::CreateGPUAllocationFromD3DResource(pResource); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* dml_provider, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + *d3d_resource = + Dml::GetD3D12ResourceFromAllocation( + dml_provider_internal->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(), + allocation); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlFreeGPUAllocation, _In_ void* ptr) { + API_IMPL_BEGIN +#ifdef USE_DML + Dml::FreeGPUAllocation(ptr); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::DmlCopyTensor, _In_ OrtExecutionProvider* dml_provider, _In_ OrtValue* src, _In_ OrtValue* dst) { + API_IMPL_BEGIN +#ifdef USE_DML + auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider); + auto status = Dml::CopyTensor(dml_provider_internal, *(src->GetMutable<onnxruntime::Tensor>()), *(dst->GetMutable<onnxruntime::Tensor>())); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + return nullptr; +#else + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "DmlCopyTensor is not available because this build does not include the DML EP."); +#endif // USE_DML + API_IMPL_END +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp new file mode 100644 index 0000000000000..4aba907e4cb86 --- /dev/null +++ b/winml/adapter/winml_adapter_environment.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation.
All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" +#include "core/session/onnxruntime_env.h" + +#ifdef USE_DML +#include "abi_custom_registry_impl.h" +#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" +#endif // USE_DML + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +class WinmlAdapterLoggingWrapper : public LoggingWrapper { + public: + WinmlAdapterLoggingWrapper(OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, void* logger_param) : LoggingWrapper(logging_function, logger_param), + profiling_function_(profiling_function) { + } + + void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const override { + if (profiling_function_) { + OrtProfilerEventRecord ort_event_record = {}; + ort_event_record.category_ = static_cast(event_record.cat); + ort_event_record.category_name_ = onnxruntime::profiling::event_categor_names_[event_record.cat]; + ort_event_record.duration_ = event_record.dur; + ort_event_record.event_name_ = event_record.name.c_str(); + ort_event_record.execution_provider_ = (event_record.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) ? event_record.args["provider"].c_str() : nullptr; + ort_event_record.op_name_ = (event_record.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) ? event_record.args["op_name"].c_str() : nullptr; + ort_event_record.process_id_ = event_record.pid; + ort_event_record.thread_id_ = event_record.tid; + ort_event_record.time_span_ = event_record.ts; + + profiling_function_(&ort_event_record); + } + } + + private: + OrtProfilingFunction profiling_function_{}; +}; + +ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, + _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, + _In_ const char* logid, _Outptr_ OrtEnv** out) { + API_IMPL_BEGIN + std::string name = logid; + std::unique_ptr<WinmlAdapterLoggingWrapper> logger = onnxruntime::make_unique<WinmlAdapterLoggingWrapper>(logging_function, profiling_function, logger_param); + + // Clear the logging manager, since only one default instance of logging manager can exist at a time. + env->SetLoggingManager(nullptr); + + auto winml_logging_manager = std::make_unique<onnxruntime::logging::LoggingManager>(std::move(logger), + static_cast<onnxruntime::logging::Severity>(default_warning_level), + false, + onnxruntime::logging::LoggingManager::InstanceType::Default, + &name); + + // Set a new default logging manager + env->SetLoggingManager(std::move(winml_logging_manager)); + return nullptr; + API_IMPL_END +} + +// Override select shape inference functions which are incomplete in ONNX with versions that are complete, +// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being +// deferred until first evaluation. It also prevents a situation where inference functions in externally +// registered schema are reachable only after upstream schema have been revised in a later OS release, +// which would be a compatibility risk. +ORT_API_STATUS_IMPL(winmla::OverrideSchema) { + API_IMPL_BEGIN +#ifdef USE_DML + static std::once_flag schema_override_once_flag; + std::call_once(schema_override_once_flag, []() { + SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); + }); +#endif // USE_DML
+ return nullptr; + API_IMPL_END +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_execution_provider.cpp b/winml/adapter/winml_adapter_execution_provider.cpp new file mode 100644 index 0000000000000..a38af2af931c3 --- /dev/null +++ b/winml/adapter/winml_adapter_execution_provider.cpp @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +struct OrtAllocatorWrapper : public OrtAllocator { + public: + OrtAllocatorWrapper(onnxruntime::AllocatorPtr impl) : impl_(impl) { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + return static_cast<OrtAllocatorWrapper*>(this_)->impl_->Alloc(size); + } + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + return static_cast<OrtAllocatorWrapper*>(this_)->impl_->Free(p); + } + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + return &(static_cast<const OrtAllocatorWrapper*>(this_)->impl_->Info()); + } + + private: + onnxruntime::AllocatorPtr impl_; +}; + +ORT_API_STATUS_IMPL(winmla::ExecutionProviderSync, _In_ OrtExecutionProvider* provider) { + API_IMPL_BEGIN + const auto execution_provider = reinterpret_cast<::onnxruntime::IExecutionProvider*>(provider); + execution_provider->Sync(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::GetProviderAllocator, _In_ OrtExecutionProvider* provider, OrtAllocator** allocator) { + API_IMPL_BEGIN + const auto execution_provider = reinterpret_cast<::onnxruntime::IExecutionProvider*>(provider); + auto allocator_ptr = execution_provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); + *allocator = new (std::nothrow) OrtAllocatorWrapper(allocator_ptr); + if (*allocator == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::GetProviderMemoryInfo, _In_ OrtExecutionProvider* provider, OrtMemoryInfo** memory_info) { + API_IMPL_BEGIN + const auto execution_provider = reinterpret_cast<::onnxruntime::IExecutionProvider*>(provider); + + auto allocator = execution_provider->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault); + + const auto& info = allocator->Info(); + *memory_info = new (std::nothrow) OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); + if (*memory_info == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::FreeProviderAllocator, _In_ OrtAllocator* allocator) { + API_IMPL_BEGIN + delete static_cast<OrtAllocatorWrapper*>(allocator); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::GetValueMemoryInfo, const OrtValue* value, OrtMemoryInfo** memory_info) { + API_IMPL_BEGIN + const auto& tensor = value->Get<onnxruntime::Tensor>(); + auto info = tensor.Location(); + *memory_info = new (std::nothrow) OrtMemoryInfo(info.name, info.alloc_type, info.device, info.id, info.mem_type); + if (*memory_info == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); + } + return nullptr; + API_IMPL_END +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp new file mode 100644 index 0000000000000..6e4d22588aec3 --- /dev/null +++ b/winml/adapter/winml_adapter_model.cpp @@ -0,0
+1,429 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "pch.h" + +#include "winml_adapter_model.h" + +#include "winml_adapter_c_api.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" + +#include <io.h> +#include <fcntl.h> +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "core/framework/onnxruntime_typeinfo.h" + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +static std::vector<const char*> GetInitializers(const onnx::ModelProto& model_proto) { + std::vector<const char*> initializers; + auto& graph = model_proto.graph(); + auto& graph_initializers = graph.initializer(); + for (auto& initializer : graph_initializers) { + initializers.push_back(initializer.name().c_str()); + } + return initializers; +} + +static std::vector<const onnx::ValueInfoProto*> GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { + auto initializers = GetInitializers(model_proto); + + std::vector<const onnx::ValueInfoProto*> inputs_without_initializers; + auto& graph = model_proto.graph(); + auto& inputs = graph.input(); + for (auto& input : inputs) { + if (input.has_name() && input.has_type()) { + auto found_it = std::find_if( + std::begin(initializers), + std::end(initializers), + [&](auto& initializer) { + return std::strcmp(initializer, input.name().c_str()) == 0; + }); + + auto is_initializer = found_it != std::end(initializers); + if (!is_initializer) { + inputs_without_initializers.push_back(&input); + } + } + } + return inputs_without_initializers; +} + +static std::vector<const onnx::ValueInfoProto*> GetOutputs(const onnx::ModelProto& model_proto) { + std::vector<const onnx::ValueInfoProto*> outputs_with_name; + auto& graph = model_proto.graph(); + auto& outputs = graph.output(); + for (auto& output : outputs) { + if (output.has_name() && output.has_type()) { + outputs_with_name.push_back(&output); + } + } + return outputs_with_name; +} + +class ModelInfo { + public: + ModelInfo(const onnx::ModelProto* model_proto) { + Initialize(model_proto); + } + + public: + // model metadata + std::string author_; + std::string name_; + std::string domain_; + std::string description_; + int64_t version_; + std::vector<std::pair<std::string, std::string>> model_metadata_; + std::vector<const onnx::ValueInfoProto*> input_features_; + std::vector<const onnx::ValueInfoProto*> output_features_; + bool requires_float16_support_; + + private: + void Initialize(const onnx::ModelProto* model_proto) { + for (auto& prop : model_proto->metadata_props()) { + model_metadata_.push_back(std::make_pair(prop.key(), prop.value())); + } + + input_features_ = GetInputsWithoutInitializers(*model_proto); + output_features_ = ::GetOutputs(*model_proto); + + auto has_producer_name = model_proto->has_producer_name(); + author_ = has_producer_name ? model_proto->producer_name() : ""; + + auto has_domain = model_proto->has_domain(); + domain_ = has_domain ? model_proto->domain() : ""; + + auto has_graph = model_proto->has_graph(); + auto graph_has_name = model_proto->graph().has_name(); + auto is_name_available = has_graph && graph_has_name; + name_ = is_name_available ? model_proto->graph().name() : ""; + + auto has_description = model_proto->has_doc_string(); + description_ = has_description ? model_proto->doc_string() : ""; + + auto has_version = model_proto->has_model_version(); + version_ = has_version ?
model_proto->model_version() : 0; + } +}; + +OrtModel::OrtModel(std::unique_ptr<onnx::ModelProto> model_proto) : model_proto_(std::move(model_proto)), + model_info_(std::make_unique<ModelInfo>(model_proto_.get())) { +} + +// factory methods for creating an ort model from a path +static OrtStatus* CreateModelProto(const char* path, std::unique_ptr<onnx::ModelProto>& out) { + int file_descriptor; + _set_errno(0); // clear errno + _sopen_s( + &file_descriptor, + path, + O_RDONLY | _O_SEQUENTIAL | _O_BINARY, + _SH_DENYWR, + _S_IREAD | _S_IWRITE); + + errno_t err = 0; + _get_errno(&err); + if (err == ENOENT) { + return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!"); + } + + if (0 > file_descriptor) { + return OrtApis::CreateStatus(ORT_NO_SUCHFILE, "Model file not found!"); + } + + google::protobuf::io::FileInputStream stream(file_descriptor); + stream.SetCloseOnDelete(true); + + auto model_proto = std::unique_ptr<onnx::ModelProto>(new onnx::ModelProto()); + + auto parse_succeeded = model_proto->ParseFromZeroCopyStream(&stream); + if (!parse_succeeded) { + return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model file!"); + } + + out = std::move(model_proto); + + return nullptr; +} + +OrtStatus* OrtModel::CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model) { + ORT_UNUSED_PARAMETER(len); + + std::unique_ptr<onnx::ModelProto> model_proto; + + if (auto status = CreateModelProto(path, model_proto)) { + return status; + } + + return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model); +} + +OrtStatus* OrtModel::CreateOrtModelFromData(void* data, size_t len, OrtModel** model) { + auto model_proto = std::unique_ptr<onnx::ModelProto>(new onnx::ModelProto()); + + auto parse_succeeded = model_proto->ParseFromArray(data, static_cast<int>(len)); + if (!parse_succeeded) { + return OrtApis::CreateStatus(ORT_INVALID_PROTOBUF, "Failed to parse model stream!"); + } + + return OrtModel::CreateOrtModelFromProto(std::move(model_proto), model); +} + +OrtStatus* OrtModel::CreateOrtModelFromProto(std::unique_ptr<onnx::ModelProto>&& model_proto, OrtModel** model) { + *model = new (std::nothrow) OrtModel(std::move(model_proto)); + if (*model == nullptr) { + return OrtApis::CreateStatus(ORT_ENGINE_ERROR, "Engine failed to create a model!"); + } + + return nullptr; +} + +const ModelInfo* OrtModel::UseModelInfo() const { + return model_info_.get(); +} + +const ONNX_NAMESPACE::ModelProto* OrtModel::UseModelProto() const { + return model_proto_.get(); +} + +std::unique_ptr<onnx::ModelProto> OrtModel::DetachModelProto() { + return std::move(model_proto_); +} + +ORT_API_STATUS_IMPL(winmla::CreateModelFromPath, const char* model_path, size_t size, OrtModel** out) { + API_IMPL_BEGIN + if (auto status = OrtModel::CreateOrtModelFromPath(model_path, size, out)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CreateModelFromData, void* data, size_t size, OrtModel** out) { + API_IMPL_BEGIN + if (auto status = OrtModel::CreateOrtModelFromData(data, size, out)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CloneModel, const OrtModel* in, OrtModel** out) { + API_IMPL_BEGIN + auto model_proto_copy = std::make_unique<onnx::ModelProto>(*in->UseModelProto()); + if (auto status = OrtModel::CreateOrtModelFromProto(std::move(model_proto_copy), out)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetAuthor, const OrtModel* model, const char** const author, size_t* len) { + API_IMPL_BEGIN + *author = model->UseModelInfo()->author_.c_str(); + *len = model->UseModelInfo()->author_.size(); + return nullptr; +
API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetName, const OrtModel* model, const char** const name, size_t* len) { + API_IMPL_BEGIN + *name = model->UseModelInfo()->name_.c_str(); + *len = model->UseModelInfo()->name_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetDomain, const OrtModel* model, const char** const domain, size_t* len) { + API_IMPL_BEGIN + *domain = model->UseModelInfo()->domain_.c_str(); + *len = model->UseModelInfo()->domain_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetDescription, const OrtModel* model, const char** const description, size_t* len) { + API_IMPL_BEGIN + *description = model->UseModelInfo()->description_.c_str(); + *len = model->UseModelInfo()->description_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetVersion, const OrtModel* model, int64_t* version) { + API_IMPL_BEGIN + *version = model->UseModelInfo()->version_; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetMetadataCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN + *count = model->UseModelInfo()->model_metadata_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetMetadata, const OrtModel* model, size_t count, const char** const key, + size_t* key_len, const char** const value, size_t* value_len) { + API_IMPL_BEGIN + *key = model->UseModelInfo()->model_metadata_[count].first.c_str(); + *key_len = model->UseModelInfo()->model_metadata_[count].first.size(); + *value = model->UseModelInfo()->model_metadata_[count].second.c_str(); + *value_len = model->UseModelInfo()->model_metadata_[count].second.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN + *count = model->UseModelInfo()->input_features_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputCount, const OrtModel* model, size_t* count) { + API_IMPL_BEGIN + *count = model->UseModelInfo()->output_features_.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputName, const OrtModel* model, size_t index, + const char** input_name, size_t* count) { + API_IMPL_BEGIN + *input_name = model->UseModelInfo()->input_features_[index]->name().c_str(); + *count = model->UseModelInfo()->input_features_[index]->name().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputName, const OrtModel* model, size_t index, + const char** output_name, size_t* count) { + API_IMPL_BEGIN + *output_name = model->UseModelInfo()->output_features_[index]->name().c_str(); + *count = model->UseModelInfo()->output_features_[index]->name().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputDescription, const OrtModel* model, size_t index, + const char** input_description, size_t* count) { + API_IMPL_BEGIN + *input_description = model->UseModelInfo()->input_features_[index]->doc_string().c_str(); + *count = model->UseModelInfo()->input_features_[index]->doc_string().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, const OrtModel* model, size_t index, + const char** output_description, size_t* count) { + API_IMPL_BEGIN + *output_description = model->UseModelInfo()->output_features_[index]->doc_string().c_str(); + *count = model->UseModelInfo()->output_features_[index]->doc_string().size(); + return nullptr; + 
API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetInputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) { + API_IMPL_BEGIN + if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->input_features_[index]->type(), type_info)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelGetOutputTypeInfo, const OrtModel* model, size_t index, OrtTypeInfo** type_info) { + API_IMPL_BEGIN + if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->output_features_[index]->type(), type_info)) { + return status; + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::ModelEnsureNoFloat16, const OrtModel* model) { + API_IMPL_BEGIN + auto model_info = model->UseModelInfo(); + auto model_proto = model->UseModelProto(); + auto& graph = model_proto->graph(); + + // The model will not contain fp16 operations if: + // 1. The model has no fp16 inputs + // 2. The model does not create any fp16 intermediary tensors via the Cast (to float16) operator + // 3. The model has no fp16 initializers + // 4. The model does not have any fp16 outputs + + // 1. Ensure that the model has no fp16 inputs + for (auto input : model_info->input_features_) { + auto& type = input->type(); + if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { + auto& tensor_type = type.tensor_type(); + if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit input (" + << input->name() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + } + + // 2. Ensure that the model does not create any fp16 intermediary tensors via the Cast (to float16) operator + for (int i = 0; i < graph.node_size(); i++) { + auto node = graph.node(i); + if (node.op_type() == "Cast" && node.domain().empty()) { + for (int attribIndex = 0; attribIndex < node.attribute_size(); attribIndex++) { + auto attribute = node.attribute(attribIndex); + if (attribute.name() == "to") { + if (attribute.i() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit intermediary tensor created by the Cast node (" + << node.name().c_str() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + } + } + } + + // 3. Ensure that the model has no fp16 initializers + for (int i = 0; i < graph.initializer_size(); i++) { + auto initializer = graph.initializer(i); + if (initializer.data_type() == onnx::TensorProto::DataType::TensorProto_DataType_FLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit initializer (" + << initializer.name().c_str() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + + // 4. Ensure that the model does not have any fp16 outputs + for (auto output : model_info->output_features_) { + auto& type = output->type(); + if (type.value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { + auto& tensor_type = type.tensor_type(); + if (tensor_type.elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + std::stringstream error_message; + error_message << "The model contains a 16-bit output (" + << output->name() + << "), but the current device does not support 16-bit float."; + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, error_message.str().c_str()); + } + } + } + return nullptr; + API_IMPL_END +} + +ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) { + delete ptr; +} \ No newline at end of file diff --git a/winml/adapter/winml_adapter_model.h b/winml/adapter/winml_adapter_model.h new file mode 100644 index 0000000000000..df245f75c7941 --- /dev/null +++ b/winml/adapter/winml_adapter_model.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "winml_adapter_c_api.h" +#include <memory> +#include "core/graph/onnx_protobuf.h" + +class ModelInfo; + +struct OrtModel { + public: + static OrtStatus* CreateOrtModelFromPath(const char* path, size_t len, OrtModel** model); + static OrtStatus* CreateOrtModelFromData(void* data, size_t len, OrtModel** model); + static OrtStatus* CreateOrtModelFromProto(std::unique_ptr<onnx::ModelProto>&& model_proto, OrtModel** model); + const ModelInfo* UseModelInfo() const; + + const onnx::ModelProto* UseModelProto() const; + std::unique_ptr<onnx::ModelProto> DetachModelProto(); + + private: + OrtModel(std::unique_ptr<onnx::ModelProto> model_proto); + OrtModel(const OrtModel& other) = delete; + OrtModel& operator=(const OrtModel& other) = delete; + + private: + std::unique_ptr<onnx::ModelProto> model_proto_; + std::unique_ptr<ModelInfo> model_info_; +}; diff --git a/winml/adapter/winml_adapter_session.cpp b/winml/adapter/winml_adapter_session.cpp new file mode 100644 index 0000000000000..1a65f1e885677 --- /dev/null +++ b/winml/adapter/winml_adapter_session.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "pch.h" + +#include "winml_adapter_c_api.h" +#include "core/session/ort_apis.h" +#include "winml_adapter_apis.h" +#include "core/framework/error_code_helper.h" + +#include "core/session/inference_session.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/onnxruntime_env.h" + +#include "winml_adapter_model.h" +#include "core/framework/utils.h" + +#ifdef USE_DML +#include "core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h" +#include "abi_custom_registry_impl.h" +#include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" +#endif // USE_DML + +namespace winmla = Windows::AI::MachineLearning::Adapter; + +// ORT intentionally requires callers derive from their session class to access +// the protected methods used below.
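// For example, SessionLoadAndPurloinModel below casts the public onnxruntime::InferenceSession*
// to this accessor type solely to reach the protected Load() overload.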
+class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSession { + public: + onnxruntime::common::Status + Load(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto) { + return onnxruntime::InferenceSession::Load(std::move(p_model_proto)); + } + const onnxruntime::SessionState& GetSessionState() { + return *session_state_; + } +}; + +ORT_API_STATUS_IMPL(winmla::CreateSessionWithoutModel, _In_ OrtEnv* env, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** session) { + API_IMPL_BEGIN + std::unique_ptr<onnxruntime::InferenceSession> inference_session; + try { + // Create the inference session + inference_session = std::make_unique<onnxruntime::InferenceSession>(options->value, env->GetLoggingManager()); + } catch (const std::exception& e) { + return OrtApis::CreateStatus(ORT_FAIL, e.what()); + } + + // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of + // byte addressable memory + std::vector<std::unique_ptr<onnxruntime::IExecutionProvider>> provider_list; + if (options) { + for (auto& factory : options->provider_factories) { + auto provider = factory->CreateProvider(); + if (provider->Type() == onnxruntime::kDmlExecutionProvider) { + if (options->value.enable_mem_pattern) { + // TODO Instead of returning an error, should we set mem pattern to false here and log a warning saying so? + // Doing so would be inconsistent with the Python API that doesn't go through this code path. + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Mem pattern should be disabled when using DML execution provider."); + } + if (options->value.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Sequential execution should be enabled when using DML execution provider."); + } + } + provider_list.push_back(std::move(provider)); + } + } + + Status status; + if (options) { + if (!options->custom_op_domains_.empty()) { + status = inference_session->AddCustomOpDomains(options->custom_op_domains_); + if (!status.IsOK()) + return onnxruntime::ToOrtStatus(status); + } + } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + inference_session->RegisterExecutionProvider(std::move(provider)); + } + } + + *session = reinterpret_cast<OrtSession*>(inference_session.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionGetExecutionProvider, _In_ OrtSession* session, _In_ size_t index, _Out_ OrtExecutionProvider** ort_provider) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast<InferenceSessionProtectedLoadAccessor*>(inference_session); + const auto& session_state = session_protected_load_accessor->GetSessionState(); + auto& provider_id = session_state.GetExecutionProviders().GetIds().at(index); + const auto& provider = session_state.GetExecutionProviders().Get(provider_id); + + *ort_provider = const_cast<OrtExecutionProvider*>(reinterpret_cast<const OrtExecutionProvider*>(provider)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionInitialize, _In_ OrtSession* session) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto status = inference_session->Initialize(); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionLoadAndPurloinModel, _In_ OrtSession* session, _In_ OrtModel* model) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast<InferenceSessionProtectedLoadAccessor*>(inference_session); + + auto status =
session_protected_load_accessor->Load(model->DetachModelProto()); + + ReleaseModel(model); + + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionStartProfiling, _In_ OrtEnv* env, _In_ OrtSession* session) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + inference_session->StartProfiling(&env->GetLoggingManager()->DefaultLogger()); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionEndProfiling, _In_ OrtSession* session) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + inference_session->EndProfiling(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::SessionRegisterGraphTransformers, _In_ OrtSession* session) { + API_IMPL_BEGIN +#ifdef USE_DML + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + + // Bug 22973884 : Fix issues with BatchNorm + Add and BatchNorm + Mul handling implicit inputs, and move from Winml to ORT + GraphTransformerHelpers::RegisterGraphTransformers(inference_session); +#endif // USE_DML + return nullptr; + API_IMPL_END +} + +inline std::list<std::shared_ptr<onnxruntime::CustomRegistry>> +GetLotusCustomRegistries(IMLOperatorRegistry* registry) { + if (registry != nullptr) { + // Down-cast to the concrete type. + // The only supported input is the AbiCustomRegistry type. + // Other implementations of IMLOperatorRegistry are forbidden. + auto abi_custom_registry = + static_cast<winmla::AbiCustomRegistry*>(registry); + + // Get the ORT registry + return abi_custom_registry->GetRegistries(); + } + return {}; +} + +ORT_API_STATUS_IMPL(winmla::SessionRegisterCustomRegistry, _In_ OrtSession* session, _In_ IMLOperatorRegistry* registry) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto custom_registries = GetLotusCustomRegistries(registry); + + // Register + for (auto& custom_registry : custom_registries) { + ORT_THROW_IF_ERROR(inference_session->RegisterCustomRegistry(custom_registry)); + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(winmla::CreateCustomRegistry, _Out_ IMLOperatorRegistry** registry) { + API_IMPL_BEGIN + auto impl = wil::MakeOrThrow<winmla::AbiCustomRegistryImpl>(); + *registry = impl.Detach(); + return nullptr; + API_IMPL_END +} + +static OrtDevice GetSessionGetInputDevice(_In_ OrtSession* session, _In_ const char* const input_name) { + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast<InferenceSessionProtectedLoadAccessor*>(inference_session); + const onnxruntime::SessionState& session_state = session_protected_load_accessor->GetSessionState(); + + std::vector<onnxruntime::SessionState::NodeInfo> node_info_vec; + session_state.GetInputNodeInfo(input_name, node_info_vec); + const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine + return *node_info.device; +} + +ORT_API_STATUS_IMPL(winmla::SessionGetInputRequiredDeviceId, _In_ OrtSession* session, _In_ const char* const input_name, _Out_ int16_t* device_id) { + auto device = GetSessionGetInputDevice(session, input_name); + *device_id = device.Id(); + return nullptr; +} + +ORT_API_STATUS_IMPL(winmla::ValueGetDeviceId, _In_ OrtValue* ort_value, _Out_ int16_t* device_id) { + auto device = ort_value->Get<onnxruntime::Tensor>().Location().device; + *device_id = device.Id(); + return nullptr; +} + +ORT_API_STATUS_IMPL(winmla::SessionCopyOneInputAcrossDevices, _In_ OrtSession* session, _In_ const
char* const input_name, + _In_ OrtValue* orig_value, _Outptr_ OrtValue** new_value) { + API_IMPL_BEGIN + auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + auto session_protected_load_accessor = + static_cast<InferenceSessionProtectedLoadAccessor*>(inference_session); + const onnxruntime::SessionState& session_state = session_protected_load_accessor->GetSessionState(); + + auto ort_value = std::make_unique<OrtValue>(); + auto status = onnxruntime::utils::CopyOneInputAcrossDevices(session_state, input_name, *orig_value, *ort_value.get()); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + + *new_value = ort_value.release(); + + return nullptr; + API_IMPL_END +} \ No newline at end of file diff --git a/winml/api/Windows.AI.MachineLearning.idl b/winml/api/Windows.AI.MachineLearning.idl index 0380f4a02b7b1..7dddba0afdbb4 100644 --- a/winml/api/Windows.AI.MachineLearning.idl +++ b/winml/api/Windows.AI.MachineLearning.idl @@ -20,7 +20,7 @@ import "windows.storage.idl"; namespace Windows.AI.MachineLearning { - [contractversion(4)] + [contractversion(3)] apicontract MachineLearningContract{}; //! Forward declarations @@ -334,18 +334,6 @@ namespace Windows.AI.MachineLearning TensorKind KeyKind{ get; }; //! Returns the properties of the map's value. ILearningModelFeatureDescriptor ValueDescriptor{ get; }; - - [contract(MachineLearningContract, 4)] - { - //! allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] MapFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - TensorKind KeyKind, - ILearningModelFeatureDescriptor ValueDescriptor - ); - } } //! \class SequenceFeatureDescriptor @@ -358,17 +346,6 @@ namespace Windows.AI.MachineLearning { //! Gets the properties of the specified feature. ILearningModelFeatureDescriptor ElementDescriptor{ get; }; - - [contract(MachineLearningContract, 4)] - { - //! allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] SequenceFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - ILearningModelFeatureDescriptor ElementDescriptor - ); - } } //! \class TensorFeatureDescriptor @@ -383,23 +360,6 @@ namespace Windows.AI.MachineLearning TensorKind TensorKind{ get; }; //! Returns the count and size of each dimension. Windows.Foundation.Collections.IVectorView<Int64> Shape{ get; }; - - [contract(MachineLearningContract, 4)] - { - //! allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] TensorFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - TensorKind TensorKind, - Int64[] Shape, - Boolean HasUnsupportedImageMetadata - ); - //! if this feature is an image but has unsupport image metadata (like a BitmapPixelFormat) - //! you can still use the runtime but without image conversion support. - //! this setting will be 'true' - Boolean HasUnsupportedImageMetadata{ get; } ; - } } //! \class ImageFeatureDescriptor @@ -418,26 +378,6 @@ namespace Windows.AI.MachineLearning UInt32 Width{ get; }; //! The height of the image. UInt32 Height{ get; }; - - [contract(MachineLearningContract, 4)] - { - //!
allows creation of feature descriptors outside the runtime (immutable) - [method_name("Create")] ImageFeatureDescriptor( - String Name, - String Description, - Boolean IsRequired, - TensorKind TensorKind, - Int64[] Shape, - Windows.Graphics.Imaging.BitmapPixelFormat BitmapPixelFormat, - Windows.Graphics.Imaging.BitmapAlphaMode BitmapAlphaMode, - UInt32 Width, - UInt32 Height - ); - - //! Returns the data type of the tensor. This is useful if you want to know - //! if it's fp16 or fp32 - TensorKind TensorKind{ get; }; - } } //! \interface ITensor diff --git a/winml/dll/module.cpp b/winml/dll/module.cpp index 531521edc834a..8c7123f880c85 100644 --- a/winml/dll/module.cpp +++ b/winml/dll/module.cpp @@ -6,19 +6,19 @@ #include #include "LearningModelDevice.h" +#include "OnnxruntimeProvider.h" using namespace winrt::Windows::AI::MachineLearning::implementation; -void __stdcall OnErrorReported(bool alreadyReported, wil::FailureInfo const &failure) WI_NOEXCEPT { +void __stdcall OnErrorReported(bool alreadyReported, wil::FailureInfo const& failure) WI_NOEXCEPT { if (!alreadyReported) { winrt::hstring message(failure.pszMessage ? failure.pszMessage : L""); telemetry_helper.LogRuntimeError( - failure.hr, - winrt::to_string(message), - failure.pszFile, - failure.pszFunction, - failure.uLineNumber - ); + failure.hr, + winrt::to_string(message), + failure.pszFile, + failure.pszFunction, + failure.uLineNumber); } } @@ -57,10 +57,10 @@ extern "C" BOOL WINAPI DllMain(_In_ HINSTANCE hInstance, DWORD dwReason, _In_ vo } extern "C" HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegistry** registry) try { - *registry = nullptr; - winrt::com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - return adapter->GetCustomRegistry(registry); + winrt::com_ptr engine_factory; + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory.put())); + WINML_THROW_IF_FAILED(engine_factory->CreateCustomRegistry(registry)); + return S_OK; } CATCH_RETURN(); diff --git a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp new file mode 100644 index 0000000000000..9aeedb7613099 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "pch.h" + +#include "OnnxruntimeCpuSessionBuilder.h" +#include "OnnxruntimeEngine.h" +#include "OnnxruntimeErrors.h" + +using namespace Windows::AI::MachineLearning; + +HRESULT OnnxruntimeCpuSessionBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) { + engine_factory_ = engine_factory; + return S_OK; +} + +HRESULT +OnnxruntimeCpuSessionBuilder::CreateSessionOptions( + OrtSessionOptions** options) { + RETURN_HR_IF_NULL(E_POINTER, options); + + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtSessionOptions* ort_options; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&ort_options), + ort_api); + + auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions); + + // set the graph optimization level to all (used to be called level 3) + RETURN_HR_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL), + ort_api); + + // Onnxruntime will use half the number of concurrent threads supported on the system + // by default. This causes MLAS to not exercise every logical core. 
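// For example, on a machine with 4 physical cores and 8 logical processors, the default
// intra-op thread pool would use only 4 threads, while std::thread::hardware_concurrency() reports 8.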
+ // We force the thread pool size to be maxed out to ensure that WinML always + // runs the fastest. + RETURN_HR_IF_NOT_OK_MSG(ort_api->SetIntraOpNumThreads(session_options.get(), std::thread::hardware_concurrency()), + ort_api); + +#ifndef _WIN64 + auto use_arena = false; +#else + auto use_arena = true; +#endif + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena), + ort_api); + + // call release() so the underlying OrtSessionOptions object isn't freed + *options = session_options.release(); + + return S_OK; +} + +HRESULT +OnnxruntimeCpuSessionBuilder::CreateSession( + OrtSessionOptions* options, + OrtSession** session) { + RETURN_HR_IF_NULL(E_POINTER, session); + + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtEnv* ort_env; + RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env)); + + OrtSession* ort_session_raw; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw), + engine_factory_->UseOrtApi()); + + auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession); + + *session = ort_session.release(); + + return S_OK; +} + +HRESULT +OnnxruntimeCpuSessionBuilder::Initialize( + OrtSession* session) { + RETURN_HR_IF_NULL(E_INVALIDARG, session); + + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session), + engine_factory_->UseOrtApi()); + + return S_OK; +} diff --git a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h new file mode 100644 index 0000000000000..d9f4a12375316 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License.
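// Illustrative usage of the builder declared below, assuming the WRL activation
// pattern used elsewhere in this patch (error handling elided):
//
//   Microsoft::WRL::ComPtr<IOrtSessionBuilder> builder;
//   Microsoft::WRL::MakeAndInitialize<OnnxruntimeCpuSessionBuilder>(&builder, engine_factory);
//   OrtSessionOptions* options = nullptr;
//   builder->CreateSessionOptions(&options);
//   OrtSession* session = nullptr;
//   builder->CreateSession(options, &session);
//   builder->Initialize(session);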
+ +#pragma once + +#include "OnnxruntimeSessionBuilder.h" + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineFactory; + +class OnnxruntimeCpuSessionBuilder : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, + IOrtSessionBuilder> { + public: + HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory); + + HRESULT STDMETHODCALLTYPE CreateSessionOptions( + OrtSessionOptions** options) override; + + HRESULT STDMETHODCALLTYPE CreateSession( + OrtSessionOptions* options, + OrtSession** session) override; + + HRESULT STDMETHODCALLTYPE Initialize( + OrtSession* session) override; + + private: + Microsoft::WRL::ComPtr<OnnxruntimeEngineFactory> engine_factory_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/adapter/FeatureDescriptorFactory.cpp b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp similarity index 58% rename from winml/adapter/FeatureDescriptorFactory.cpp rename to winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp index 196ccf517dc64..db4cf60062a45 100644 --- a/winml/adapter/FeatureDescriptorFactory.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp @@ -5,7 +5,7 @@ #include -#include "FeatureDescriptorFactory.h" +#include "OnnxruntimeDescriptorConverter.h" #include "ImageFeatureDescriptor.h" #include "MapFeatureDescriptor.h" #include "SequenceFeatureDescriptor.h" @@ -13,7 +13,11 @@ #include "winrt/windows.foundation.collections.h" #include "winrt/windows.graphics.imaging.h" -#include "WinMLAdapter.h" + +#include "OnnxruntimeEngine.h" + +#include "OnnxruntimeErrors.h" + using namespace winrt::Windows::AI::MachineLearning; // BitmapPixelFormat constants @@ -42,152 +46,64 @@ static const char* c_supported_nominal_ranges[] = namespace Windows::AI::MachineLearning { - -// since this code is now running inside ONNXRUNTIME we need to shortcut -// this a bit when creating winrt objects. This will help.
- -/* extern "C" -HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept; - -#ifdef _M_IX86 -#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12") -#else -#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory") -#endif -*/ - -bool starts_with(std::wstring_view value, std::wstring_view match) noexcept -{ - return 0 == value.compare(0, match.size(), match); -} - -EXTERN_C IMAGE_DOS_HEADER __ImageBase; - -std::wstring GetModulePath() -{ - std::wstring val; - wchar_t modulePath[MAX_PATH] = { 0 }; - GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath)); - wchar_t drive[_MAX_DRIVE]; - wchar_t dir[_MAX_DIR]; - wchar_t filename[_MAX_FNAME]; - wchar_t ext[_MAX_EXT]; - _wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT); - - val = drive; - val += dir; - - return val; -} - -extern "C" int32_t __stdcall WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept { - *factory = nullptr; - HSTRING classId_hstring = (HSTRING)classId; - std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) }; - HMODULE library{ nullptr }; - - std::wstring winmlDllPath = GetModulePath() + L"Windows.AI.MachineLearning.dll"; - - if (starts_with(name, L"Windows.AI.MachineLearning.")) - { - const wchar_t* libPath = winmlDllPath.c_str(); - library = LoadLibraryW(libPath); - } - else - { - return RoGetActivationFactory(classId_hstring, iid, factory); - } - - if (!library) - { - return HRESULT_FROM_WIN32(GetLastError()); - } - - using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory); - auto call = reinterpret_cast(GetProcAddress(library, "DllGetActivationFactory")); - - if (!call) - { - HRESULT const hr = HRESULT_FROM_WIN32(GetLastError()); - WINRT_VERIFY(FreeLibrary(library)); - return hr; - } - - winrt::com_ptr activation_factory; - HRESULT const hr = call(classId_hstring, activation_factory.put_void()); - - if (FAILED(hr)) - { - WINRT_VERIFY(FreeLibrary(library)); - return hr; - } - - if (winrt::guid(iid) != winrt::guid_of()) - { - return activation_factory->QueryInterface(iid, factory); - } - - *factory = activation_factory.detach(); - return S_OK; -} - // Forward declare CreateFeatureDescriptor static winml::ILearningModelFeatureDescriptor CreateFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata); static TensorKind -TensorKindFromOnnxDataType( - ONNX_NAMESPACE::TensorProto_DataType dataType) { - using TensorType = ONNX_NAMESPACE::TensorProto_DataType; +TensorKindFromONNXTensorElementDataType(ONNXTensorElementDataType dataType) { switch (dataType) { - case TensorType::TensorProto_DataType_BOOL: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { return TensorKind::Boolean; } - case TensorType::TensorProto_DataType_STRING: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { return TensorKind::String; } - case TensorType::TensorProto_DataType_FLOAT16: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { return TensorKind::Float16; } - case TensorType::TensorProto_DataType_FLOAT: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { return TensorKind::Float; } - case 
TensorType::TensorProto_DataType_DOUBLE: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { return TensorKind::Double; } - case TensorType::TensorProto_DataType_INT8: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { return TensorKind::Int8; } - case TensorType::TensorProto_DataType_INT16: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { return TensorKind::Int16; } - case TensorType::TensorProto_DataType_INT32: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { return TensorKind::Int32; } - case TensorType::TensorProto_DataType_INT64: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { return TensorKind::Int64; } - case TensorType::TensorProto_DataType_UINT8: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { return TensorKind::UInt8; } - case TensorType::TensorProto_DataType_UINT16: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { return TensorKind::UInt16; } - case TensorType::TensorProto_DataType_UINT32: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { return TensorKind::UInt32; } - case TensorType::TensorProto_DataType_UINT64: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { return TensorKind::UInt64; } - case TensorType::TensorProto_DataType_COMPLEX64: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: { return TensorKind::Complex64; } - case TensorType::TensorProto_DataType_COMPLEX128: { + case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: { return TensorKind::Complex128; } - default: { return TensorKind::Undefined; } + default: { + return TensorKind::Undefined; + } } } @@ -240,26 +156,10 @@ TensorKindToString(TensorKind tensorKind) { return "complex128"; } case TensorKind::Undefined: - default: { return "undefined"; } - } -} - -static std::vector -ConvertShapeProtoToVector( - const ::onnx::TensorShapeProto& shape_proto) { - std::vector shape; - for (int i = 0; i < shape_proto.dim_size(); i++) { - auto& dim = shape_proto.dim(i); - if (dim.has_dim_param()) { - shape.push_back(-1); - } else if (dim.has_dim_value()) { - shape.push_back(dim.dim_value()); - } else { - winrt::throw_hresult(E_INVALIDARG); + default: { + return "undefined"; } } - - return shape; } static const char* @@ -410,16 +310,16 @@ enum class TensorType { Tensor_Data, static TensorType GetTensorType( - const ::onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + OrtTypeInfo* type_info, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); + const char* denotation; + size_t len; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetDenotationFromTypeInfo(type_info, &denotation, &len), + engine_factory->UseOrtApi()); - THROW_HR_IF_MSG( - E_FAIL, - type_proto.has_tensor_type() == false, - "Malformed onnx file."); - - auto has_image_denotation = type_proto.denotation() == "IMAGE"; + constexpr char c_image[] = "IMAGE"; + auto has_image_denotation = strncmp(denotation, c_image, _countof(c_image)) == 0; if (!has_image_denotation) { return TensorType::Tensor_Data; } @@ -430,9 +330,15 @@ GetTensorType( // Check if the tensor value_info_proto is of type float. 
// IMAGE tensors MUST be of type float - const auto& tensor_type = type_proto.tensor_type(); - auto tensor_kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(tensor_type.elem_type())); + const OrtTensorTypeAndShapeInfo* tensor_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType tensor_element_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), + engine_factory->UseOrtApi()); + + auto tensor_kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); auto is_float_tensor = tensor_kind == TensorKind::Float; if (!is_float_tensor) { log_stream << "Unsupported image with " << TensorKindToString(tensor_kind) @@ -471,7 +377,7 @@ GetTensorType( has_unsupported_image_metadata); if (is_tensor_improperly_annotated_as_image) { - TraceLoggingWrite(winmla::winml_trace_logging_provider, + TraceLoggingWrite(winml_trace_logging_provider, "WinMLInputValidation", TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TraceLoggingLevel(WINEVENT_LEVEL_WARNING), @@ -491,21 +397,35 @@ GetTensorType( static winml::ILearningModelFeatureDescriptor CreateTensorFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata, bool has_unsupported_image_metadata) { - const auto& type_proto = value_info_proto->type(); - const auto& tensor_type = type_proto.tensor_type(); - auto shape = WinML::ConvertShapeProtoToVector(tensor_type.shape()); - auto kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(tensor_type.elem_type())); - - TensorFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), // description - value_info_proto->name().empty() == false, // is_required + auto type_info = feature_descriptor->type_info_.get(); + + const OrtTensorTypeAndShapeInfo* tensor_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info), + engine_factory->UseOrtApi()); + size_t num_dims; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims), + engine_factory->UseOrtApi()); + + auto shape = std::vector(num_dims); + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType tensor_element_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), + engine_factory->UseOrtApi()); + + auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); + + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, // description kind, shape, + feature_descriptor->name_length_ > 0, // is_required has_unsupported_image_metadata); return descriptor.as(); @@ -513,13 +433,27 @@ CreateTensorFeatureDescriptor( static winml::ILearningModelFeatureDescriptor CreateImageFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); - const auto& tensor_type = 
type_proto.tensor_type(); - auto shape = WinML::ConvertShapeProtoToVector(tensor_type.shape()); - auto kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(tensor_type.elem_type())); + auto type_info = feature_descriptor->type_info_.get(); + + const OrtTensorTypeAndShapeInfo* tensor_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->CastTypeInfoToTensorInfo(type_info, &tensor_info), + engine_factory->UseOrtApi()); + + size_t num_dims; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensionsCount(tensor_info, &num_dims), + engine_factory->UseOrtApi()); + + auto shape = std::vector(num_dims); + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetDimensions(tensor_info, shape.data(), shape.size()), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType tensor_element_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), + engine_factory->UseOrtApi()); + auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); // pixel format and alpha auto pixel_format_value = FetchMetadataValueOrNull(metadata, c_bitmap_pixel_format_key); @@ -527,18 +461,13 @@ CreateImageFeatureDescriptor( auto pixel_format = format_info.first; auto alpha_mode = format_info.second; - // paulm: commenting this out during layering. gamma and nominal are never used - // since we only support one of them. if a non support one is set, they all fall back - // to TensorFeatureDescriptor (invalid image metadata) -#ifdef DONE_LAYERING // color space gamma value - auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key); - auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value); + auto color_space_gamma_value = FetchMetadataValueOrNull(metadata, c_color_space_key); + auto color_space_gamma = CreateImageColorSpaceGamma(color_space_gamma_value); // nominal range - auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key); - auto nominal_range = CreateImageNominalPixelRange(nominal_range_value); -#endif + auto nominal_range_value = FetchMetadataValueOrNull(metadata, c_nominal_range_key); + auto nominal_range = CreateImageNominalPixelRange(nominal_range_value); // The current code assumes that the shape will be in NCHW. // Should the model metadata be read instead??? 
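// For example, an NCHW image tensor with shape {1, 3, 224, 224} yields height = shape[2] = 224 and width = shape[3] = 224 below.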
@@ -546,42 +475,59 @@ CreateImageFeatureDescriptor( const int c_width_dimension = 3; auto height = static_cast(shape[c_height_dimension]); auto width = static_cast(shape[c_width_dimension]); - ImageFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), - value_info_proto->name().empty() == false, // is_required + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, kind, shape, + feature_descriptor->name_length_ > 0, // is_required pixel_format, alpha_mode, width, - height); + height, + nominal_range, + color_space_gamma); return descriptor.as(); } static winml::ILearningModelFeatureDescriptor CreateMapFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); - auto type_proto_map = type_proto.map_type(); + auto type_info = feature_descriptor->type_info_.get(); + + const OrtMapTypeInfo* map_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToMapTypeInfo(type_info, &map_info), + engine_factory->UseOrtApi()); + + ONNXTensorElementDataType map_key_data_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapKeyType(map_info, &map_key_data_type), + engine_factory->UseOrtApi()); - auto key_kind = WinML::TensorKindFromOnnxDataType( - onnx::TensorProto_DataType(type_proto_map.key_type())); + auto key_kind = WinML::TensorKindFromONNXTensorElementDataType(map_key_data_type); - onnx::ValueInfoProto dummy_value_info_proto; - dummy_value_info_proto.set_name(value_info_proto->name().c_str()); - dummy_value_info_proto.set_doc_string(value_info_proto->doc_string().c_str()); - *dummy_value_info_proto.mutable_type() = type_proto_map.value_type(); + OrtTypeInfo* map_value_type_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetMapValueType(map_info, &map_value_type_info), + engine_factory->UseOrtApi()); + + UniqueOrtTypeInfo unique_map_value_type_info(map_value_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo); + + OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper; + dummy_ort_value_info_wrapper.description_ = feature_descriptor->description_; + dummy_ort_value_info_wrapper.description_length_ = feature_descriptor->description_length_; + dummy_ort_value_info_wrapper.name_ = feature_descriptor->name_; + dummy_ort_value_info_wrapper.name_length_ = feature_descriptor->name_length_; + dummy_ort_value_info_wrapper.type_info_ = std::move(unique_map_value_type_info); auto value_descriptor = - CreateFeatureDescriptor(&dummy_value_info_proto, metadata); + CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata); - MapFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), - value_info_proto->name().empty() == false, // is_rRequired + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, + feature_descriptor->name_length_ > 0, // is_required key_kind, value_descriptor); return descriptor.as(); @@ -589,24 +535,35 @@ CreateMapFeatureDescriptor( static winml::ILearningModelFeatureDescriptor CreateSequenceFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const 
OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); - // assert(typeProto->has_sequence_type()); - auto type_proto_sequence = type_proto.sequence_type(); + auto type_info = feature_descriptor->type_info_.get(); + + const OrtSequenceTypeInfo* sequence_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->CastTypeInfoToSequenceTypeInfo(type_info, &sequence_info), + engine_factory->UseOrtApi()); + + OrtTypeInfo* sequence_element_type_info; + THROW_IF_NOT_OK_MSG(engine_factory->UseWinmlAdapterApi()->GetSequenceElementType(sequence_info, &sequence_element_type_info), + engine_factory->UseOrtApi()); - onnx::ValueInfoProto dummy_value_info_proto; - dummy_value_info_proto.set_name(value_info_proto->name().c_str()); - dummy_value_info_proto.set_doc_string(value_info_proto->doc_string().c_str()); - *dummy_value_info_proto.mutable_type() = type_proto_sequence.elem_type(); + UniqueOrtTypeInfo unique_sequence_element_type_info(sequence_element_type_info, engine_factory->UseOrtApi()->ReleaseTypeInfo); + + OnnxruntimeValueInfoWrapper dummy_ort_value_info_wrapper; + dummy_ort_value_info_wrapper.description_ = feature_descriptor->description_; + dummy_ort_value_info_wrapper.description_length_ = feature_descriptor->description_length_; + dummy_ort_value_info_wrapper.name_ = feature_descriptor->name_; + dummy_ort_value_info_wrapper.name_length_ = feature_descriptor->name_length_; + dummy_ort_value_info_wrapper.type_info_ = std::move(unique_sequence_element_type_info); auto element_descriptor = - CreateFeatureDescriptor(&dummy_value_info_proto, metadata); + CreateFeatureDescriptor(engine_factory, &dummy_ort_value_info_wrapper, metadata); - SequenceFeatureDescriptor descriptor( - WinML::Strings::HStringFromUTF8(value_info_proto->name()), - WinML::Strings::HStringFromUTF8(value_info_proto->doc_string()), - value_info_proto->name().empty() == false, // is_required + auto descriptor = winrt::make( + feature_descriptor->name_, + feature_descriptor->description_, + feature_descriptor->name_length_ > 0, // is_required element_descriptor); return descriptor.as(); @@ -614,36 +571,43 @@ CreateSequenceFeatureDescriptor( static winml::ILearningModelFeatureDescriptor CreateFeatureDescriptor( - const onnx::ValueInfoProto* value_info_proto, + OnnxruntimeEngineFactory* engine_factory, + const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata) { - const auto& type_proto = value_info_proto->type(); + auto type_info = feature_descriptor->type_info_.get(); + + ONNXType onnx_type; + THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetOnnxTypeFromTypeInfo(type_info, &onnx_type), + engine_factory->UseOrtApi()); - using ValueCase = ::onnx::TypeProto::ValueCase; - switch (type_proto.value_case()) { - case ValueCase::kTensorType: { - auto tensor_type = - GetTensorType(value_info_proto, metadata); + switch (onnx_type) { + case ONNXType::ONNX_TYPE_TENSOR: { + auto tensor_type = GetTensorType(engine_factory, type_info, metadata); if (tensor_type == TensorType::Tensor_Image) { return CreateImageFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata); } else { auto has_unsupported_image_metadata = tensor_type == TensorType::Tensor_Data_UnsupportedImageMetadata; return CreateTensorFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata, has_unsupported_image_metadata); } } - case ValueCase::kMapType: { + case ONNXType::ONNX_TYPE_MAP: { return 
CreateMapFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata); } - case ValueCase::kSequenceType: { + case ONNXType::ONNX_TYPE_SEQUENCE: { return CreateSequenceFeatureDescriptor( - value_info_proto, + engine_factory, + feature_descriptor, metadata); } default: @@ -651,18 +615,17 @@ CreateFeatureDescriptor( } } -FeatureDescriptorFactory::FeatureDescriptorFactory( - const std::unordered_map& metadata) : metadata_(metadata) {} +OnnxruntimeDescriptorConverter::OnnxruntimeDescriptorConverter( + OnnxruntimeEngineFactory* engine_factory, + const std::unordered_map& metadata) : engine_factory_(engine_factory), metadata_(metadata) {} wfc::IVector -FeatureDescriptorFactory::CreateDescriptorsFromValueInfoProtos( - const std::vector& value_info_protos) { - auto features = - winrt::single_threaded_vector(); - - for (auto value_info_proto : value_info_protos) { - auto descriptor = WinML::CreateFeatureDescriptor(value_info_proto, metadata_); - features.Append(descriptor); +OnnxruntimeDescriptorConverter::ConvertToLearningModelDescriptors(const std::vector& descriptors) { + auto features = winrt::single_threaded_vector(); + + for (const auto& descriptor : descriptors) { + auto learning_model_descriptor = WinML::CreateFeatureDescriptor(engine_factory_.Get(), &descriptor, metadata_); + features.Append(learning_model_descriptor); } return features; diff --git a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h new file mode 100644 index 0000000000000..4fab8bc443e7a --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#pragma once + +#include "pch.h" + +namespace Windows::AI::MachineLearning { + +struct OnnxruntimeValueInfoWrapper { + OnnxruntimeValueInfoWrapper() : type_info_(UniqueOrtTypeInfo(nullptr, nullptr)) {} + const char* name_ = nullptr; + size_t name_length_ = 0; + const char* description_ = nullptr; + size_t description_length_ = 0; + UniqueOrtTypeInfo type_info_; +}; + +class OnnxruntimeEngineFactory; + +struct OnnxruntimeDescriptorConverter { + OnnxruntimeDescriptorConverter( + OnnxruntimeEngineFactory* engine_factory, + const std::unordered_map& model_metadata); + + wfc::IVector + ConvertToLearningModelDescriptors(const std::vector& descriptors); + + private: + Microsoft::WRL::ComPtr engine_factory_; + const std::unordered_map& metadata_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp new file mode 100644 index 0000000000000..8f23e6864f73e --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
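+
+// OnnxruntimeDmlSessionBuilder builds DirectML-backed sessions in three steps:
+// CreateSessionOptions enables full graph optimization, disables memory patterns
+// (which conflict with how the DML allocator hands out memory), and registers the
+// DML execution provider ahead of a CPU fallback; CreateSession creates the ORT
+// session without a model so one can be attached later; Initialize then sets the
+// default DML rounding mode and flushes queued D3D12 work.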
+
+#include "pch.h"
+
+#ifdef USE_DML
+
+#include "OnnxruntimeDmlSessionBuilder.h"
+#include "OnnxruntimeEngine.h"
+#include "OnnxruntimeErrors.h"
+#include "LearningModelDevice.h"
+
+using namespace Windows::AI::MachineLearning;
+
+HRESULT OnnxruntimeDmlSessionBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, ID3D12Device* device, ID3D12CommandQueue* queue) {
+ engine_factory_ = engine_factory;
+ device_.copy_from(device);
+ queue_.copy_from(queue);
+ return S_OK;
+}
+
+HRESULT
+OnnxruntimeDmlSessionBuilder::CreateSessionOptions(
+ OrtSessionOptions** options) {
+ RETURN_HR_IF_NULL(E_POINTER, options);
+
+ auto ort_api = engine_factory_->UseOrtApi();
+ auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
+
+ OrtSessionOptions* ort_options;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateSessionOptions(&ort_options),
+ ort_api);
+
+ auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions);
+
+ // set the graph optimization level to all (used to be called level 3)
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->SetSessionGraphOptimizationLevel(session_options.get(), GraphOptimizationLevel::ORT_ENABLE_ALL),
+ ort_api);
+
+ // Disable the mem pattern session option for DML. It will cause problems with how memory is allocated.
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()),
+ ort_api);
+
+ // Request the DML EP
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML(session_options.get(), device_.get(), queue_.get()),
+ ort_api);
+
+#ifndef _WIN64
+ auto use_arena = false;
+#else
+ auto use_arena = true;
+#endif
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_CPU(session_options.get(), use_arena),
+ ort_api);
+
+ // call release() so the underlying OrtSessionOptions object isn't freed
+ *options = session_options.release();
+
+ return S_OK;
+}
+
+HRESULT OnnxruntimeDmlSessionBuilder::CreateSession(
+ OrtSessionOptions* options,
+ OrtSession** session) {
+ RETURN_HR_IF_NULL(E_POINTER, session);
+
+ auto ort_api = engine_factory_->UseOrtApi();
+ auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
+
+ OrtEnv* ort_env;
+ RETURN_IF_FAILED(engine_factory_->GetOrtEnvironment(&ort_env));
+
+ OrtSession* ort_session_raw;
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CreateSessionWithoutModel(ort_env, options, &ort_session_raw),
+ engine_factory_->UseOrtApi());
+ auto ort_session = UniqueOrtSession(ort_session_raw, ort_api->ReleaseSession);
+
+ *session = ort_session.release();
+
+ return S_OK;
+}
+
+HRESULT OnnxruntimeDmlSessionBuilder::Initialize(
+ OrtSession* session) {
+ RETURN_HR_IF_NULL(E_INVALIDARG, session);
+ auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
+
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionInitialize(session),
+ engine_factory_->UseOrtApi());
+
+ OrtExecutionProvider* ort_provider;
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider),
+ engine_factory_->UseOrtApi());
+
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true),
+ engine_factory_->UseOrtApi());
+
+ // Flush the D3D12 work from the DML execution provider
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider),
+ engine_factory_->UseOrtApi());
+
+ return S_OK;
+}
+
+#endif  // USE_DML
\ No newline at end of file
diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h
b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h new file mode 100644 index 0000000000000..0f651e823a532 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "OnnxruntimeSessionBuilder.h" + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineFactory; + +class OnnxruntimeDmlSessionBuilder : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IOrtSessionBuilder> { + public: + HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, ID3D12Device* device, ID3D12CommandQueue* queue); + + HRESULT STDMETHODCALLTYPE CreateSessionOptions( + OrtSessionOptions** options) override; + + HRESULT STDMETHODCALLTYPE CreateSession( + OrtSessionOptions* options, + OrtSession** session) override; + + HRESULT STDMETHODCALLTYPE Initialize( + OrtSession* session) override; + + private: + Microsoft::WRL::ComPtr engine_factory_; + winrt::com_ptr device_; + winrt::com_ptr queue_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp new file mode 100644 index 0000000000000..05b114bdaf6d7 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -0,0 +1,1265 @@ +#include "pch.h" + +#include "OnnxruntimeEngine.h" + +#include "PheonixSingleton.h" +#include "OnnxruntimeEnvironment.h" +#include "OnnxruntimeEngineBuilder.h" +#include "OnnxruntimeModel.h" +#include "OnnxruntimeSessionBuilder.h" +#include "OnnxruntimeErrors.h" + +using namespace WinML; + +static const OrtApi* GetVersionedOrtApi() { + static const uint32_t ort_version = 1; + const auto ort_api_base = OrtGetApiBase(); + return ort_api_base->GetApi(ort_version); +} + +static const WinmlAdapterApi* GetVersionedWinmlAdapterApi() { + return OrtGetWinMLAdapter(GetVersionedOrtApi()); +} + +static ONNXTensorElementDataType +ONNXTensorElementDataTypeFromTensorKind(winml::TensorKind kind) { + switch (kind) { + case winml::TensorKind::Boolean: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; + } + case winml::TensorKind::String: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; + } + case winml::TensorKind::Float16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + } + case winml::TensorKind::Float: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + case winml::TensorKind::Double: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + } + case winml::TensorKind::Int8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + } + case winml::TensorKind::Int16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; + } + case winml::TensorKind::Int32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + } + case winml::TensorKind::Int64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } + case winml::TensorKind::UInt8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + } + case winml::TensorKind::UInt16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; + } + case winml::TensorKind::UInt32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; + } + case winml::TensorKind::UInt64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; + } + 
case winml::TensorKind::Complex64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; + } + case winml::TensorKind::Complex128: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; + } + default: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + } + } +} + +OnnxruntimeValue::OnnxruntimeValue() : value_(nullptr, nullptr), allocator_(nullptr, nullptr) {} + +OnnxruntimeValue::~OnnxruntimeValue() { + value_.reset(nullptr); + allocator_.reset(nullptr); +} + +HRESULT OnnxruntimeValue::RuntimeClassInitialize(OnnxruntimeEngine* engine, UniqueOrtValue&& ort_value, UniqueOrtAllocator&& allocator) { + engine_ = engine; + value_ = std::move(ort_value); + allocator_ = std::move(allocator); + + return S_OK; +} + +HRESULT OnnxruntimeValue::IsEmpty(bool* out) { + *out = UseOrtValue() == nullptr; + return S_OK; +} + +HRESULT OnnxruntimeValue::IsCpu(bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); + + OrtMemoryInfo* ort_memory_info; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetValueMemoryInfo(value_.get(), &ort_memory_info), + ort_api); + auto memory_info = UniqueOrtMemoryInfo(ort_memory_info, ort_api->ReleaseMemoryInfo); + + const char* name; + RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetName(memory_info.get(), &name), + ort_api); + + OrtMemType type; + RETURN_HR_IF_NOT_OK_MSG(ort_api->MemoryInfoGetMemType(memory_info.get(), &type), + ort_api); + + *out = !strcmp(name, "Cpu") || + type == OrtMemType::OrtMemTypeCPUOutput || + type == OrtMemType::OrtMemTypeCPUInput; + return S_OK; +} + +static int64_t ShapeSize(const int64_t* shape, size_t count) { + // for each dim + int64_t size = 1; + for (int i = 0; i < count; i++) { + // find out it's total size + size *= shape[i]; + // make sure there are no invalid dimensions (-1 or any invalid shape) + THROW_HR_IF(E_INVALIDARG, shape[i] <= 0); + } + return size; +} + +static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value, + OrtTensorTypeAndShapeInfo* type_and_shape_info) { + std::vector out; + + size_t size; + THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info, &size), + ort_api); + + std::vector shape(size); + THROW_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info, &shape[0], size), + ort_api); + + auto length = ShapeSize(shape.data(), shape.size()); + + // make a big buffer to hold all the string data + size_t buffer_length; + THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorDataLength(ort_value, &buffer_length), + ort_api); + + std::vector strings; + std::unique_ptr buffer(new uint8_t[buffer_length]); + std::vector offsets(length); + + THROW_IF_NOT_OK_MSG(ort_api->GetStringTensorContent(ort_value, buffer.get(), buffer_length, offsets.data(), offsets.size()), + ort_api); + + // now go build all the strings + for (auto i = 0; i < length; ++i) { + size_t str_len = 0; + // are we on the last one? 
+ if (i == (length - 1)) { + str_len = buffer_length - offsets[i]; + } else { + str_len = offsets[i + 1] - offsets[i]; + } + strings.push_back(std::string_view(reinterpret_cast(buffer.get() + offsets[i]), str_len)); + } + + return std::make_shared>(std::move(strings), std::move(buffer)); +} + +HRESULT OnnxruntimeValue::GetResource(WinML::Resource& out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); + + void* mutable_data = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(value_.get(), &mutable_data), + ort_api); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(engine_->UseOrtSession(), 0, &ort_provider), + ort_api); + + bool is_cpu = false; + if (SUCCEEDED(IsCpu(&is_cpu)) && !is_cpu) { + void* resource; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data, + reinterpret_cast(&resource)), + ort_api); + out = WinML::Resource(resource, [](void*) { /*do nothing, as this pointer is actually a com pointer! */ }); + } else { + int is_tensor; + RETURN_HR_IF_NOT_OK_MSG(ort_api->IsTensor(value_.get(), &is_tensor), + ort_api); + if (is_tensor == 0) { + out = WinML::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! */ }); + return S_OK; + } + + OrtTensorTypeAndShapeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info), + ort_api); + auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo); + + ONNXTensorElementDataType data_type; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(type_and_shape_info.get(), &data_type), + ort_api); + + if (data_type == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + auto strings = GetStrings(ort_api, value_.get(), info); + auto string_data = strings->first.data(); + out = WinML::Resource(string_data, [capture_strings = strings](void*) { /*This deleter does nothing but capture the strings, which extends the lifetime of the returned strings.*/ }); + } else { + out = WinML::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! 
*/ }); + } + } + return S_OK; +} + +HRESULT OnnxruntimeValue::IsTensor(bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + + ONNXType type = ONNXType::ONNX_TYPE_UNKNOWN; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueType(value_.get(), &type), + ort_api); + *out = type == ONNXType::ONNX_TYPE_TENSOR; + return S_OK; +} + +HRESULT OnnxruntimeValue::IsOfTensorType(winml::TensorKind kind, bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + OrtTensorTypeAndShapeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info), + ort_api); + auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo); + + ONNXTensorElementDataType data_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorElementType(type_and_shape_info.get(), &data_type), + ort_api); + + *out = data_type == ONNXTensorElementDataTypeFromTensorKind(kind); + return S_OK; +} + +HRESULT OnnxruntimeValue::GetTensorShape(std::vector& shape_vector) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + OrtTensorTypeAndShapeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorTypeAndShape(value_.get(), &info), + ort_api); + auto type_and_shape_info = UniqueOrtTensorTypeAndShapeInfo(info, ort_api->ReleaseTensorTypeAndShapeInfo); + + size_t size; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(type_and_shape_info.get(), &size), + ort_api); + + std::vector shape(size); + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetDimensions(type_and_shape_info.get(), &shape[0], size), + ort_api); + + shape_vector = std::move(shape); + return S_OK; +} + +static bool EnsureMapTypeInfo(OnnxruntimeEngine* engine, OrtTypeInfo* type_info, winml::TensorKind key_kind, winml::TensorKind value_kind) { + auto ort_api = engine->GetEngineFactory()->UseOrtApi(); + auto winml_adapter_api = engine->GetEngineFactory()->UseWinmlAdapterApi(); + + const OrtMapTypeInfo* map_info; + THROW_IF_NOT_OK_MSG(winml_adapter_api->CastTypeInfoToMapTypeInfo(type_info, &map_info), + ort_api); + + ONNXTensorElementDataType map_key_type; + THROW_IF_NOT_OK_MSG(winml_adapter_api->GetMapKeyType(map_info, &map_key_type), + ort_api); + + if (map_key_type == ONNXTensorElementDataTypeFromTensorKind(key_kind)) { + OrtTypeInfo* value_info; + THROW_IF_NOT_OK_MSG(winml_adapter_api->GetMapValueType(map_info, &value_info), + ort_api); + auto map_value_info = UniqueOrtTypeInfo(value_info, ort_api->ReleaseTypeInfo); + + const OrtTensorTypeAndShapeInfo* value_tensor_info = nullptr; + THROW_IF_NOT_OK_MSG(ort_api->CastTypeInfoToTensorInfo(map_value_info.get(), &value_tensor_info), + ort_api); + + if (value_tensor_info) { + ONNXTensorElementDataType map_value_tensor_type; + THROW_IF_NOT_OK_MSG(ort_api->GetTensorElementType(value_tensor_info, &map_value_tensor_type), + ort_api); + + if (map_value_tensor_type == ONNXTensorElementDataTypeFromTensorKind(value_kind)) { + size_t num_dims; + THROW_IF_NOT_OK_MSG(ort_api->GetDimensionsCount(value_tensor_info, &num_dims), + ort_api); + + return num_dims == 0; + } + } + } + return false; +} + +HRESULT OnnxruntimeValue::IsOfMapType(winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) { + auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); + + OrtTypeInfo* info = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info), + ort_api); + auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo); + + ONNXType type; + 
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type),
+ ort_api);
+
+ if (type == ONNXType::ONNX_TYPE_MAP) {
+ *out = EnsureMapTypeInfo(engine_.Get(), unique_type_info.get(), key_kind, value_kind);
+ } else {
+ *out = false;
+ }
+
+ return S_OK;
+}
+
+HRESULT OnnxruntimeValue::IsOfVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) {
+ auto ort_api = engine_->GetEngineFactory()->UseOrtApi();
+ auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi();
+
+ *out = false;
+
+ OrtTypeInfo* info = nullptr;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTypeInfo(value_.get(), &info),
+ ort_api);
+ auto unique_type_info = UniqueOrtTypeInfo(info, ort_api->ReleaseTypeInfo);
+
+ ONNXType type;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetOnnxTypeFromTypeInfo(unique_type_info.get(), &type),
+ ort_api);
+
+ if (type == ONNXType::ONNX_TYPE_SEQUENCE) {
+ const OrtSequenceTypeInfo* sequence_info;
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CastTypeInfoToSequenceTypeInfo(unique_type_info.get(), &sequence_info),
+ ort_api);
+
+ OrtTypeInfo* element_info;
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetSequenceElementType(sequence_info, &element_info),
+ ort_api);
+ auto unique_element_info = UniqueOrtTypeInfo(element_info, ort_api->ReleaseTypeInfo);
+
+ *out = EnsureMapTypeInfo(engine_.Get(), unique_element_info.get(), key_kind, value_kind);
+ }
+ return S_OK;
+}
+
+HRESULT OnnxruntimeValue::SetParameter(IUnknown* param) {
+ param_ = param;
+ return S_OK;
+}
+
+OrtValue* OnnxruntimeValue::UseOrtValue() {
+ return value_.get();
+}
+
+HRESULT OnnxruntimeValue::AssignOrtValue(OrtValue* in) {
+ value_.reset(in);
+ return S_OK;
+}
+
+OnnxruntimeEngine::OnnxruntimeEngine() : session_(nullptr, nullptr) {
+}
+
+HRESULT OnnxruntimeEngine::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory,
+ UniqueOrtSession&& session,
+ IOrtSessionBuilder* session_builder) {
+ engine_factory_ = engine_factory;
+ session_ = std::move(session);
+ session_builder_ = session_builder;
+ return S_OK;
+}
+
+HRESULT OnnxruntimeEngine::LoadModel(_In_ IModel* model) {
+ Microsoft::WRL::ComPtr onnxruntime_model;
+ RETURN_IF_FAILED(model->QueryInterface(IID_PPV_ARGS(&onnxruntime_model)));
+
+ OrtModel* ort_model;
+ RETURN_IF_FAILED(onnxruntime_model->DetachOrtModel(&ort_model));
+
+ auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
+
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionLoadAndPurloinModel(session_.get(), ort_model),
+ engine_factory_->UseOrtApi());
+ return S_OK;
+}
+
+HRESULT OnnxruntimeEngine::Initialize() {
+ RETURN_IF_FAILED(session_builder_->Initialize(session_.get()));
+ return S_OK;
+}
+
+HRESULT OnnxruntimeEngine::RegisterGraphTransformers() {
+ auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterGraphTransformers(session_.get()),
+ engine_factory_->UseOrtApi());
+ return S_OK;
+}
+
+HRESULT OnnxruntimeEngine::RegisterCustomRegistry(IMLOperatorRegistry* registry) {
+ auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionRegisterCustomRegistry(session_.get(), registry),
+ engine_factory_->UseOrtApi());
+ return S_OK;
+}
+
+HRESULT OnnxruntimeEngine::EndProfiling() {
+ auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
+ RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionEndProfiling(session_.get()),
+ engine_factory_->UseOrtApi());
+ return S_OK;
+}
+
+HRESULT OnnxruntimeEngine::StartProfiling() {
+ auto
winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtEnv* ort_env; + engine_factory_->GetOrtEnvironment(&ort_env); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionStartProfiling(ort_env, session_.get()), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::FlushContext() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider), + engine_factory_->UseOrtApi()); + return S_OK; +} + +HRESULT OnnxruntimeEngine::TrimUploadHeap() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderTrimUploadHeap(ort_provider), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(ort_provider), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +HRESULT OnnxruntimeEngine::CopyValueAcrossDevices(IValue* src, IValue* dest) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + auto src_value = static_cast(src); + auto dest_value = static_cast(dest); + + bool is_empty; + auto has_null_source = (SUCCEEDED(src_value->IsEmpty(&is_empty)) && is_empty); + RETURN_HR_IF(E_FAIL, has_null_source); + + auto has_null_dest = (SUCCEEDED(dest_value->IsEmpty(&is_empty)) && is_empty); + RETURN_HR_IF(E_FAIL, has_null_dest); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCopyTensor(ort_provider, src_value->UseOrtValue(), dest_value->UseOrtValue()), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +HRESULT OnnxruntimeEngine::Sync() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ExecutionProviderSync(ort_provider), + engine_factory_->UseOrtApi()); + + return S_OK; +} + +OrtSession* OnnxruntimeEngine::UseOrtSession() { + return session_.get(); +} + +const OrtApi* OnnxruntimeEngine::UseOrtApi() { + return engine_factory_->UseOrtApi(); +} + +OnnxruntimeEngineFactory* OnnxruntimeEngine::GetEngineFactory() { + return engine_factory_.Get(); +} + +/* +* OnnxruntimeEngine::CreateTensorValue +* +* Used by callers like ImageFeatureValue to allocate a cpu or gpu OrtValue with ORT owned memory. +* In the image feature value case, tensorization creates temporary buffers, and will need to copy the value from +* its source location to the ort value. 
Since a copy is required, there is need to preserve the caller's memory locations. +* We simply allocate memory with ORT and copy the tensorized values into it. +*/ +HRESULT OnnxruntimeEngine::CreateTensorValue(const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + OrtAllocator* ort_allocator; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderAllocator(ort_provider, &ort_allocator), + engine_factory_->UseOrtApi()); + + auto unique_allocator = UniqueOrtAllocator(ort_allocator, winml_adapter_api->FreeProviderAllocator); // the release here should probably not return anything + + OrtValue* ort_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorAsOrtValue(unique_allocator.get(), shape, count, ONNXTensorElementDataTypeFromTensorKind(kind), &ort_value), + ort_api); + auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), std::move(unique_allocator))); + return S_OK; +} + +using DmlAllocatorResource = std::unique_ptr; +class DmlAllocatorWrapper : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IUnknown> { + public: + DmlAllocatorWrapper() : dml_resource_(nullptr, nullptr) {} + + HRESULT RuntimeClassInitialize(DmlAllocatorResource&& dml_resource) { + dml_resource_ = std::move(dml_resource); + return S_OK; + } + + private: + DmlAllocatorResource dml_resource_; +}; + +/* +* OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource +* +* Used by callers like TensorBase to allocate a gpu OrtValue based on a called owned ID3D12Resource. +* WinML cannot use ORT allocators here since they will allocate the ID3D12Resource and force a copy from the user provided value. 
+*/ +HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource(ID3D12Resource* d3d_resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtExecutionProvider* ort_provider; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider), + engine_factory_->UseOrtApi()); + + OrtMemoryInfo* dml_memory = nullptr; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->GetProviderMemoryInfo(ort_provider, &dml_memory), + engine_factory_->UseOrtApi()); + + void* dml_allocator_resource; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlCreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource), + engine_factory_->UseOrtApi()); + + auto unique_dml_allocator_resource = + DmlAllocatorResource(dml_allocator_resource, + [](void* ptr) { + GetVersionedWinmlAdapterApi()->DmlFreeGPUAllocation(ptr); + }); + + // create the OrtValue as a tensor letting ort know that we own the data buffer + OrtValue* ort_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue( + dml_memory, + unique_dml_allocator_resource.get(), + d3d_resource->GetDesc().Width, + shape, + count, + ONNXTensorElementDataTypeFromTensorKind(kind), + &ort_value), + ort_api); + auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue); + + Microsoft::WRL::ComPtr out_value; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&out_value, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr))); + + // Cache the allocator on the value so it destructs appropriately when the value is dropped + Microsoft::WRL::ComPtr dml_allocator_resource_wrapper; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&dml_allocator_resource_wrapper, std::move(unique_dml_allocator_resource))); + + RETURN_IF_FAILED(out_value->SetParameter(dml_allocator_resource_wrapper.Get())); + + *out = out_value.Detach(); + + return S_OK; +} + +/* +* OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy +* +* Used by callers like TensorString to allocate a cpu OrtValue and populate the contents with use specified data. +* WinML cannot use CreateTensorWithDataAsOrtValue since externally allocated strings are not supported on the c-abi. +* The c-abi string implementation requires a copy the external buffer into its own internal std::string copy. +* In addition, strings have different APIs on the c-abi like FillStringTensor to populate the buffer, and so strings +* have a different calling pattern than other Tensor types of simple data types. +*/ +HRESULT OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy(const char* const* data, size_t num_elements, const int64_t* shape, size_t count, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + RETURN_IF_FAILED(CreateTensorValue(shape, count, winml::TensorKind::String, out)); + + auto ort_value = reinterpret_cast(*out)->UseOrtValue(); + RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(ort_value, reinterpret_cast(data), num_elements), + ort_api); + return S_OK; +} + +/* +* OnnxruntimeEngine::CreateTensorValueFromExternalBuffer +* +* Used by callers like TensorBase to allocate a cpu OrtValue that is backed by caller owned memory. 
+*/ +HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalBuffer(void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + + if (kind == winml::TensorKind::String) { + // String buffers cannot be passed into the ort api directly because ort c-api tensor strings cannot be backed by external memory + return E_NOTIMPL; + } + + // TODO: what is the difference between the device allocator and the arena allocator? + OrtMemoryInfo* cpu_memory; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory), + ort_api); + + OrtValue* ort_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateTensorWithDataAsOrtValue( + cpu_memory, + data, + size_in_bytes, + shape, + count, + ONNXTensorElementDataTypeFromTensorKind(kind), + &ort_value), + ort_api); + auto unique_value = UniqueOrtValue(ort_value, ort_api->ReleaseValue); + + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; +} + +/* +* OnnxruntimeEngine::CreateNullValue +* +* Used by callers like TensorBase and the binding object to allocate a cpu OrtValue that is empty. +* This is used for WinML unbound outputs. +*/ +HRESULT OnnxruntimeEngine::CreateNullValue(_Out_ IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto unique_value = UniqueOrtValue(nullptr, ort_api->ReleaseValue); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; +} + +template +struct AbiTypeInfo { + using CppWinRTType = TAbiType; + using OrtType = TAbiType; + using ResourceType = TAbiType; +}; + +template <> +struct AbiTypeInfo { + using CppWinRTType = winrt::hstring; + using OrtType = const char*; + using ResourceType = std::string_view; +}; + +template +typename auto CppwinrtTypeToOrtType(TCppwinrtType raw) { + return raw; +} + +template <> +typename auto CppwinrtTypeToOrtType(winrt::hstring raw) { + return WinML::Strings::UTF8FromHString(raw); +} + +template +typename auto ResourceTypeToCppwinrtType(typename AbiTypeInfo::ResourceType value) { + return value; +} + +template <> +typename auto ResourceTypeToCppwinrtType(typename AbiTypeInfo::ResourceType value) { + return WinML::Strings::HStringFromUTF8(value.data(), value.size()); +} + +template +auto CastToWinrtMap(IInspectable* map_insp) { + using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; + using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; + + ::winrt::Windows::Foundation::IInspectable map_inspectable; + ::winrt::Windows::Foundation::Collections::IMap map; + winrt::copy_from_abi(map_inspectable, map_insp); + map_inspectable.as(map); + return map; +} + +template +auto CastToWinrtSequenceOfMaps(IInspectable* sequence_insp) { + using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; + using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; + + using cppwinrt_element_map_type = ::winrt::Windows::Foundation::Collections::IMap; + using cppwinrt_sequence_type = ::winrt::Windows::Foundation::Collections::IVector; + cppwinrt_sequence_type sequence; + ::winrt::Windows::Foundation::IInspectable sequence_inspectable; + winrt::copy_from_abi(sequence_inspectable, sequence_insp); + sequence_inspectable.as(sequence); + return sequence; +} + +template +struct FillMapTensors { + static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* 
keys_ort_value, OrtValue* values_ort_value) {
+ typename AbiTypeInfo<TKey>::OrtType* keys_mutable_data;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast<void**>(&keys_mutable_data)),
+ ort_api);
+
+ typename AbiTypeInfo<TValue>::OrtType* values_mutable_data;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast<void**>(&values_mutable_data)),
+ ort_api);
+
+ auto map = CastToWinrtMap<TKey, TValue>(map_insp);
+ size_t index = 0;
+ for (const auto& pair : map) {
+ keys_mutable_data[index] = CppwinrtTypeToOrtType(pair.Key());
+ values_mutable_data[index] = CppwinrtTypeToOrtType(pair.Value());
+ index++;
+ }
+ return S_OK;
+ }
+};
+
+template <winml::TensorKind TValue>
+struct FillMapTensors<winml::TensorKind::String, TValue> {
+ static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
+ typename AbiTypeInfo<TValue>::OrtType* values_mutable_data;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(values_ort_value, reinterpret_cast<void**>(&values_mutable_data)),
+ ort_api);
+
+ auto map = CastToWinrtMap<winml::TensorKind::String, TValue>(map_insp);
+ size_t index = 0;
+ std::vector<std::string> keys;
+ for (const auto& pair : map) {
+ keys.push_back(CppwinrtTypeToOrtType(pair.Key()));
+ values_mutable_data[index] = CppwinrtTypeToOrtType(pair.Value());
+ index++;
+ }
+
+ std::vector<const char*> raw_keys;
+ std::transform(
+ keys.begin(),
+ keys.end(),
+ std::back_inserter(raw_keys),
+ [&](auto& str) { return str.c_str(); });
+
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_keys.data(), raw_keys.size()),
+ ort_api);
+
+ return S_OK;
+ }
+};
+
+template <winml::TensorKind TKey>
+struct FillMapTensors<TKey, winml::TensorKind::String> {
+ static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
+ typename AbiTypeInfo<TKey>::OrtType* keys_mutable_data;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetTensorMutableData(keys_ort_value, reinterpret_cast<void**>(&keys_mutable_data)),
+ ort_api);
+
+ auto map = CastToWinrtMap<TKey, winml::TensorKind::String>(map_insp);
+ size_t index = 0;
+ std::vector<std::string> values;
+ for (const auto& pair : map) {
+ keys_mutable_data[index] = CppwinrtTypeToOrtType(pair.Key());
+ values.push_back(CppwinrtTypeToOrtType(pair.Value()));
+ index++;
+ }
+
+ std::vector<const char*> raw_values;
+ std::transform(
+ values.begin(),
+ values.end(),
+ std::back_inserter(raw_values),
+ [&](auto& str) { return str.c_str(); });
+
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(values_ort_value, raw_values.data(), raw_values.size()),
+ ort_api);
+ return S_OK;
+ }
+};
+
+template <>
+struct FillMapTensors<winml::TensorKind::String, winml::TensorKind::String> {
+ static HRESULT Run(const OrtApi* ort_api, IInspectable* map_insp, OrtValue* keys_ort_value, OrtValue* values_ort_value) {
+ auto map = CastToWinrtMap<winml::TensorKind::String, winml::TensorKind::String>(map_insp);
+ size_t index = 0;
+ std::vector<std::string> keys;
+ std::vector<std::string> values;
+ for (const auto& pair : map) {
+ keys.push_back(CppwinrtTypeToOrtType(pair.Key()));
+ values.push_back(CppwinrtTypeToOrtType(pair.Value()));
+ index++;
+ }
+
+ std::vector<const char*> raw_keys;
+ std::transform(
+ keys.begin(),
+ keys.end(),
+ std::back_inserter(raw_keys),
+ [&](auto& str) { return str.c_str(); });
+
+ std::vector<const char*> raw_values;
+ std::transform(
+ values.begin(),
+ values.end(),
+ std::back_inserter(raw_values),
+ [&](auto& str) { return str.c_str(); });
+
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(keys_ort_value, raw_keys.data(), raw_keys.size()),
+ ort_api);
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(values_ort_value, raw_values.data(), raw_values.size()),
+ ort_api);
+ return S_OK;
+ }
+};
+
+template <winml::TensorKind TKey, winml::TensorKind TValue>
+HRESULT CreateMapValue(OnnxruntimeEngine* engine, IInspectable* map_insp, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_
IValue** out) { + auto ort_api = engine->UseOrtApi(); + auto map = CastToWinrtMap(map_insp); + std::vector shape = {static_cast(map.Size())}; + + winrt::com_ptr key_value; + RETURN_IF_FAILED(engine->CreateTensorValue(shape.data(), shape.size(), key_kind, key_value.put())); + auto keys_ort_value = static_cast(key_value.get())->UseOrtValue(); + + winrt::com_ptr value_value; + RETURN_IF_FAILED(engine->CreateTensorValue(shape.data(), shape.size(), value_kind, value_value.put())); + auto values_ort_value = static_cast(value_value.get())->UseOrtValue(); + + auto hr = FillMapTensors::Run(ort_api, map_insp, keys_ort_value, values_ort_value); + RETURN_IF_FAILED(hr); + + OrtValue* inputs[2] = {keys_ort_value, values_ort_value}; + + OrtValue* map_value; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateValue(inputs, 2, ONNXType::ONNX_TYPE_MAP, &map_value), + ort_api); + auto unique_map_ort_value = UniqueOrtValue(map_value, ort_api->ReleaseValue); + + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, engine, std::move(unique_map_ort_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; +} + +static auto GetMapValueCreator(OnnxruntimeEngine* engine, winml::TensorKind key_kind, winml::TensorKind value_kind) { + using namespace std::placeholders; + if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Int64) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Int64, _2); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Float, _2); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Double) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Double, _2); + } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::String) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::Int64, winml::TensorKind::String, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Int64) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::Int64, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::Float, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Double) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::Double, _2); + } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::String) { + return std::bind(&CreateMapValue, engine, _1, winml::TensorKind::String, winml::TensorKind::String, _2); + } + + THROW_HR(E_NOTIMPL); +} + +HRESULT OnnxruntimeEngine::CreateMapValue(IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) { + return GetMapValueCreator(this, key_kind, value_kind)(map, out); +} + +template +HRESULT CreateSequenceOfMapsValue(OnnxruntimeEngine* engine, IInspectable* sequence_insp, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) { + auto ort_api = engine->UseOrtApi(); + auto sequence = CastToWinrtSequenceOfMaps(sequence_insp); + + std::vector> element_values; + for (auto element : sequence) { + winrt::com_ptr element_value; + 
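// Each ABI map element is first converted into its own ORT map value; the sequence OrtValue below is assembled from those elements. +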
RETURN_IF_FAILED(engine->CreateMapValue(reinterpret_cast<IInspectable*>(winrt::get_abi(element)), key_kind, value_kind, element_value.put()));
+ element_values.push_back(element_value);
+ }
+
+ std::vector<OrtValue*> element_ort_values;
+ std::transform(element_values.begin(),
+ element_values.end(),
+ std::back_inserter(element_ort_values),
+ [](auto value) { return static_cast<OnnxruntimeValue*>(value.get())->UseOrtValue(); });
+
+ OrtValue* sequence_value;
+ RETURN_HR_IF_NOT_OK_MSG(
+ ort_api->CreateValue(element_ort_values.data(), element_ort_values.size(),
+ ONNXType::ONNX_TYPE_SEQUENCE, &sequence_value),
+ ort_api);
+ auto unique_sequence_ort_value = UniqueOrtValue(sequence_value, ort_api->ReleaseValue);
+
+ RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, engine, std::move(unique_sequence_ort_value), UniqueOrtAllocator(nullptr, nullptr)));
+ return S_OK;
+}
+
+static auto GetSequenceOfMapsValueCreator(OnnxruntimeEngine* engine, winml::TensorKind key_kind, winml::TensorKind value_kind) {
+ using namespace std::placeholders;
+ if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) {
+ return std::bind(&CreateSequenceOfMapsValue<winml::TensorKind::String, winml::TensorKind::Float>, engine, _1, winml::TensorKind::String, winml::TensorKind::Float, _2);
+ } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) {
+ return std::bind(&CreateSequenceOfMapsValue<winml::TensorKind::Int64, winml::TensorKind::Float>, engine, _1, winml::TensorKind::Int64, winml::TensorKind::Float, _2);
+ }
+
+ THROW_HR(E_NOTIMPL);
+}
+
+HRESULT OnnxruntimeEngine::CreateSequenceOfMapsValue(IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) {
+ RETURN_IF_FAILED(GetSequenceOfMapsValueCreator(this, key_kind, value_kind)(sequence, out));
+ return S_OK;
+}
+
+template <winml::TensorKind TKey, winml::TensorKind TValue>
+static HRESULT FillAbiSequence(IInspectable* sequence_insp, std::vector<::winrt::Windows::Foundation::IInspectable>& elements) {
+ using cppwinrt_key_type = typename AbiTypeInfo<TKey>::CppWinRTType;
+ using cppwinrt_value_type = typename AbiTypeInfo<TValue>::CppWinRTType;
+ auto sequence = CastToWinrtSequenceOfMaps<TKey, TValue>(sequence_insp);
+ for (auto element : elements) {
+ ::winrt::Windows::Foundation::Collections::IMap<cppwinrt_key_type, cppwinrt_value_type> map_element;
+ element.as(map_element);
+ sequence.Append(map_element);
+ }
+ return S_OK;
+}
+
+static auto GetAbiSequenceFiller(winml::TensorKind key_kind, winml::TensorKind value_kind) {
+ if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) {
+ return &FillAbiSequence<winml::TensorKind::String, winml::TensorKind::Float>;
+ } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) {
+ return &FillAbiSequence<winml::TensorKind::Int64, winml::TensorKind::Float>;
+ }
+ THROW_HR(E_NOTIMPL);
+}
+
+static winrt::Windows::Foundation::IInspectable CreateMap(winml::TensorKind key_kind, winml::TensorKind value_kind) {
+ winrt::Windows::Foundation::IInspectable map_insp;
+ if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) {
+ auto map = winrt::single_threaded_map<winrt::hstring, float>();
+ map.as(map_insp);
+ } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) {
+ auto map = winrt::single_threaded_map<int64_t, float>();
+ map.as(map_insp);
+ }
+
+ return map_insp;
+}
+
+HRESULT OnnxruntimeEngine::FillSequenceOfMapsValue(IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* sequence_value) {
+ auto ort_api = engine_factory_->UseOrtApi();
+ auto onnxruntime_sequence_value = static_cast<OnnxruntimeValue*>(sequence_value);
+ auto ort_sequence_value = onnxruntime_sequence_value->UseOrtValue();
+
+ OrtAllocator* ort_allocator;
+
RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&ort_allocator), ort_api); // This should not be freed as this owned by ort + + size_t num_elements; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueCount(ort_sequence_value, &num_elements), ort_api); + + // get the elements + std::vector<::winrt::Windows::Foundation::IInspectable> element_map_inspectables; + for (int index = 0; index < num_elements; index++) { + OrtValue* elements_ort_value = nullptr; + RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_sequence_value, index, ort_allocator, &elements_ort_value), ort_api); + auto unique_element_value = UniqueOrtValue(elements_ort_value, ort_api->ReleaseValue); + + winrt::com_ptr element_value; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(element_value.put(), this, std::move(unique_element_value), UniqueOrtAllocator(nullptr, nullptr))); + + ::winrt::Windows::Foundation::IInspectable map_inspectable = CreateMap(key_kind, value_kind); + RETURN_IF_FAILED(FillFromMapValue(reinterpret_cast(winrt::get_abi(map_inspectable)), key_kind, value_kind, element_value.get())); + element_map_inspectables.push_back(map_inspectable); + } + + GetAbiSequenceFiller(key_kind, value_kind)(sequence, element_map_inspectables); + return S_OK; +} + +HRESULT OnnxruntimeEngine::CreateOneInputAcrossDevices(const char* name, IValue* src, IValue** out) { + auto ort_api = engine_factory_->UseOrtApi(); + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + auto src_value = static_cast(src); + + bool is_set; + auto is_empty = SUCCEEDED(src_value->IsEmpty(&is_set)) && is_set; + auto is_tensor = SUCCEEDED(src_value->IsTensor(&is_set)) && is_set; + + if (is_tensor && !is_empty) { + int16_t source_location; + int16_t input_required_location; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(src_value->UseOrtValue(), &source_location), + ort_api); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetInputRequiredDeviceId(session_.get(), name, &input_required_location), + ort_api); + + if (source_location != input_required_location) { + OrtValue* dest_ort_value = nullptr; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionCopyOneInputAcrossDevices(session_.get(), name, + src_value->UseOrtValue(), &dest_ort_value), + ort_api); + auto unique_dest_ort_value = UniqueOrtValue(dest_ort_value, ort_api->ReleaseValue); + + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(unique_dest_ort_value), UniqueOrtAllocator(nullptr, nullptr))); + return S_OK; + } + } + + *out = src; + (*out)->AddRef(); + return S_OK; +} + +HRESULT OnnxruntimeEngine::Run(const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) { + auto ort_api = engine_factory_->UseOrtApi(); + + OrtRunOptions* run_options; + RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateRunOptions(&run_options), + ort_api); + auto unique_run_options = UniqueOrtRunOptions(run_options, ort_api->ReleaseRunOptions); + + std::vector input_ort_values; + std::transform( + inputs, + inputs + num_inputs, + std::back_inserter(input_ort_values), + [&](auto& input) { + auto input_value = static_cast(input); + return input_value->UseOrtValue(); + }); + + std::vector output_ort_values; + std::transform( + outputs, + outputs + num_outputs, + std::back_inserter(output_ort_values), + [&](auto& output) { + auto output_value = static_cast(output); + return output_value->UseOrtValue(); + }); + + RETURN_HR_IF_NOT_OK_MSG(ort_api->Run(session_.get(), + unique_run_options.get(), + input_names, + 
input_ort_values.data(),
+ num_inputs,
+ output_names,
+ num_outputs,
+ output_ort_values.data()),
+ ort_api);
+
+ for (size_t index = 0; index < num_outputs; index++) {
+ auto output_value = static_cast<OnnxruntimeValue*>(outputs[index]);
+ if (output_value->UseOrtValue() != output_ort_values[index]) {
+ RETURN_IF_FAILED(output_value->AssignOrtValue(output_ort_values[index]));
+ }
+ }
+
+ return S_OK;
+}
+
+template <winml::TensorKind TKey, winml::TensorKind TValue>
+HRESULT FillAbiMap(IInspectable* map_insp, size_t num_elements, void* keys_data, void* values_data) {
+ auto map = CastToWinrtMap<TKey, TValue>(map_insp);
+
+ auto keys = reinterpret_cast<typename AbiTypeInfo<TKey>::ResourceType*>(keys_data);
+ auto values = reinterpret_cast<typename AbiTypeInfo<TValue>::ResourceType*>(values_data);
+
+ for (size_t i = 0; i < num_elements; ++i) {
+ map.Insert(
+ ResourceTypeToCppwinrtType<TKey>(keys[i]),
+ ResourceTypeToCppwinrtType<TValue>(values[i]));
+ }
+ return S_OK;
+}
+
+static auto GetAbiMapFiller(winml::TensorKind key_kind, winml::TensorKind value_kind) {
+ using namespace std::placeholders;
+ if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Int64) {
+ return std::bind(&FillAbiMap<winml::TensorKind::Int64, winml::TensorKind::Int64>, _1, _2, _3, _4);
+ } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Float) {
+ return std::bind(&FillAbiMap<winml::TensorKind::Int64, winml::TensorKind::Float>, _1, _2, _3, _4);
+ } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::Double) {
+ return std::bind(&FillAbiMap<winml::TensorKind::Int64, winml::TensorKind::Double>, _1, _2, _3, _4);
+ } else if (key_kind == winml::TensorKind::Int64 && value_kind == winml::TensorKind::String) {
+ return std::bind(&FillAbiMap<winml::TensorKind::Int64, winml::TensorKind::String>, _1, _2, _3, _4);
+ } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Int64) {
+ return std::bind(&FillAbiMap<winml::TensorKind::String, winml::TensorKind::Int64>, _1, _2, _3, _4);
+ } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) {
+ return std::bind(&FillAbiMap<winml::TensorKind::String, winml::TensorKind::Float>, _1, _2, _3, _4);
+ } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Double) {
+ return std::bind(&FillAbiMap<winml::TensorKind::String, winml::TensorKind::Double>, _1, _2, _3, _4);
+ } else if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::String) {
+ return std::bind(&FillAbiMap<winml::TensorKind::String, winml::TensorKind::String>, _1, _2, _3, _4);
+ }
+
+ THROW_HR(E_NOTIMPL);
+}
+
+HRESULT OnnxruntimeEngine::FillFromMapValue(IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* map_value) {
+ auto ort_api = engine_factory_->UseOrtApi();
+ auto onnxruntime_map_value = static_cast<OnnxruntimeValue*>(map_value);
+ auto ort_map_value = onnxruntime_map_value->UseOrtValue();
+
+ OrtAllocator* ort_allocator;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetAllocatorWithDefaultOptions(&ort_allocator),
+ ort_api); // This should not be freed as it is owned by ort
+
+ // get the keys
+ OrtValue* keys_ort_value = nullptr;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 0, ort_allocator, &keys_ort_value),
+ ort_api);
+ auto unique_keys_value = UniqueOrtValue(keys_ort_value, ort_api->ReleaseValue);
+ winrt::com_ptr keys_value;
+ RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(keys_value.put(), this, std::move(unique_keys_value), UniqueOrtAllocator(nullptr, nullptr)));
+
+ // get the values
+ OrtValue* values_ort_value = nullptr;
+ RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_map_value, 1, ort_allocator, &values_ort_value),
+ ort_api);
+ auto unique_values_value = UniqueOrtValue(values_ort_value, ort_api->ReleaseValue);
+ winrt::com_ptr values_value;
+ RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(values_value.put(), this, std::move(unique_values_value), UniqueOrtAllocator(nullptr, nullptr)));
+
+ std::vector<int64_t> keys_shape;
+
keys_value->GetTensorShape(keys_shape); + + WinML::Resource keys_data; + RETURN_IF_FAILED(keys_value->GetResource(keys_data)); + WinML::Resource values_data; + RETURN_IF_FAILED(values_value->GetResource(values_data)); + + auto num_elements = ShapeSize(keys_shape.data(), keys_shape.size()); + GetAbiMapFiller(key_kind, value_kind)(map, num_elements, keys_data.get(), values_data.get()); + + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::RuntimeClassInitialize() { + ort_api_ = GetVersionedOrtApi(); + winml_adapter_api_ = GetVersionedWinmlAdapterApi(); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::EnsureEnvironment() { + if (environment_ == nullptr) { + std::lock_guard lock(mutex_); + if (environment_ == nullptr) { + environment_ = PheonixSingleton(ort_api_); + } + } + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) { + RETURN_IF_FAILED(EnsureEnvironment()); + + OrtModel* ort_model = nullptr; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateModelFromPath(model_path, len, &ort_model), + ort_api_); + + auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(model))); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) { + RETURN_IF_FAILED(EnsureEnvironment()); + OrtModel* ort_model = nullptr; + if (auto status = winml_adapter_api_->CreateModelFromData(data, size, &ort_model)) { + return E_INVALIDARG; + } + + auto model = UniqueOrtModel(ort_model, winml_adapter_api_->ReleaseModel); + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(out, this, std::move(model))); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineFactory::CreateEngineBuilder(_Outptr_ Windows::AI::MachineLearning::IEngineBuilder** out) { + RETURN_IF_FAILED(EnsureEnvironment()); + Microsoft::WRL::ComPtr onnxruntime_engine_builder; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine_builder, this)); + RETURN_IF_FAILED(onnxruntime_engine_builder.CopyTo(out)); + return S_OK; +} + +const OrtApi* OnnxruntimeEngineFactory::UseOrtApi() { + return ort_api_; +} + +const WinmlAdapterApi* OnnxruntimeEngineFactory::UseWinmlAdapterApi() { + return winml_adapter_api_; +} + +HRESULT OnnxruntimeEngineFactory::GetOrtEnvironment(OrtEnv** ort_env) { + RETURN_IF_FAILED(EnsureEnvironment()); + RETURN_IF_FAILED(environment_->GetOrtEnvironment(ort_env)); + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::EnableDebugOutput(bool is_enabled) { + RETURN_IF_FAILED(EnsureEnvironment()); + RETURN_IF_FAILED(environment_->EnableDebugOutput(is_enabled)); + return S_OK; +} + +HRESULT OnnxruntimeEngineFactory::CreateCustomRegistry(IMLOperatorRegistry** registry) { + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api_->CreateCustomRegistry(registry), + ort_api_); + return S_OK; +} + +STDAPI CreateOnnxruntimeEngineFactory(_Out_ Windows::AI::MachineLearning::IEngineFactory** engine_factory) { + Microsoft::WRL::ComPtr onnxruntime_engine_factory; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine_factory)); + RETURN_IF_FAILED(onnxruntime_engine_factory.CopyTo(engine_factory)); + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h new file mode 100644 index 0000000000000..6cb940c3a22a9 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -0,0 +1,143 @@ +#include 
"iengine.h" + +#include + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineBuilder; +class OnnxruntimeEngineFactory; +class OnnxruntimeEnvironment; +class OnnxruntimeModel; +class OnnxruntimeEngine; + +struct IOrtSessionBuilder; + +class OnnxruntimeValue : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IValue> { + public: + OnnxruntimeValue(); + ~OnnxruntimeValue(); + + HRESULT RuntimeClassInitialize(OnnxruntimeEngine* engine, UniqueOrtValue&& value, UniqueOrtAllocator&& allocator); + + STDMETHOD(IsEmpty) + (bool* out) override; + STDMETHOD(IsCpu) + (bool* out) override; + STDMETHOD(GetResource) + (WinML::Resource& resource) override; + STDMETHOD(IsTensor) + (bool* out) override; + STDMETHOD(IsOfTensorType) + (winml::TensorKind kind, bool* out) override; + STDMETHOD(GetTensorShape) + (std::vector& shape_vector) override; + STDMETHOD(IsOfMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) override; + STDMETHOD(IsOfVectorMapType) + (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) override; + + HRESULT(SetParameter) + (IUnknown* param); + OrtValue* UseOrtValue(); + HRESULT AssignOrtValue(OrtValue* ptr); + + private: + Microsoft::WRL::ComPtr engine_; + Microsoft::WRL::ComPtr param_; + UniqueOrtValue value_; + UniqueOrtAllocator allocator_; +}; + +class OnnxruntimeEngine : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IEngine> { + public: + OnnxruntimeEngine(); + HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, UniqueOrtSession&& session, IOrtSessionBuilder* session_builder); + + STDMETHOD(LoadModel) + (_In_ IModel* model) override; + STDMETHOD(Initialize) + () override; + STDMETHOD(RegisterGraphTransformers) + () override; + STDMETHOD(RegisterCustomRegistry) + (IMLOperatorRegistry* registry) override; + STDMETHOD(EndProfiling) + () override; + STDMETHOD(StartProfiling) + () override; + STDMETHOD(FlushContext) + () override; + STDMETHOD(TrimUploadHeap) + () override; + STDMETHOD(ReleaseCompletedReferences) + () override; + STDMETHOD(Sync) + () override; + STDMETHOD(CreateTensorValue) + (const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; + STDMETHOD(CreateTensorValueFromExternalD3DResource) + (ID3D12Resource* resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; + STDMETHOD(CreateTensorValueFromExternalBuffer) + (void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) override; + STDMETHOD(CreateStringTensorValueFromDataWithCopy) + (const char* const* data, size_t num_elements, const int64_t* shape, size_t count, _Out_ IValue** out) override; + STDMETHOD(CreateNullValue) + (_Out_ IValue** out) override; + STDMETHOD(CreateMapValue) + (IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) override; + STDMETHOD(CreateSequenceOfMapsValue) + (IInspectable* map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue** out) override; + STDMETHOD(CreateOneInputAcrossDevices) + (const char* name, IValue* src, IValue** dest) override; + STDMETHOD(CopyValueAcrossDevices) + (IValue* src, IValue* dest) override; + STDMETHOD(Run) + (const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) override; + STDMETHOD(FillFromMapValue) + (IInspectable* map, winml::TensorKind key_kind, 
winml::TensorKind value_kind, IValue* value) override; + STDMETHOD(FillSequenceOfMapsValue) + (IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* value) override; + + OrtSession* UseOrtSession(); + const OrtApi* UseOrtApi(); + OnnxruntimeEngineFactory* GetEngineFactory(); + + private: + Microsoft::WRL::ComPtr engine_factory_; + Microsoft::WRL::ComPtr session_builder_; + UniqueOrtSession session_; +}; + +class OnnxruntimeEngineFactory : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IEngineFactory> { + public: + HRESULT RuntimeClassInitialize(); + STDMETHOD(CreateModel) + (_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) override; + STDMETHOD(CreateModel) + (_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) override; + STDMETHOD(CreateEngineBuilder) + (IEngineBuilder** engine_builder) override; + STDMETHOD(EnableDebugOutput) + (bool is_enabled) override; + STDMETHOD(CreateCustomRegistry) + (_Out_ IMLOperatorRegistry** registry) override; + + const OrtApi* UseOrtApi(); + const WinmlAdapterApi* UseWinmlAdapterApi(); + HRESULT EnsureEnvironment(); + HRESULT GetOrtEnvironment(_Out_ OrtEnv** ort_env); + + private: + const OrtApi* ort_api_ = nullptr; + const WinmlAdapterApi* winml_adapter_api_ = nullptr; + std::shared_ptr environment_; + std::mutex mutex_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp new file mode 100644 index 0000000000000..ecfb6561657c9 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp @@ -0,0 +1,72 @@ +#include "pch.h" + +#include "OnnxruntimeEngine.h" +#include "OnnxruntimeEngineBuilder.h" +#include "OnnxruntimeCpuSessionBuilder.h" + +#ifdef USE_DML +#include "OnnxruntimeDmlSessionBuilder.h" +#endif + +#include "OnnxruntimeErrors.h" +using namespace WinML; + +HRESULT OnnxruntimeEngineBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) { + engine_factory_ = engine_factory; + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(Windows::AI::MachineLearning::IEngine** out) { + auto ort_api = engine_factory_->UseOrtApi(); + + Microsoft::WRL::ComPtr onnxruntime_session_builder; + + if (device_ == nullptr) { + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_session_builder, engine_factory_.Get())); + } else { +#ifdef USE_DML + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_session_builder, engine_factory_.Get(), device_.Get(), queue_.Get())); +#endif + } + + OrtSessionOptions* ort_options; + RETURN_IF_FAILED(onnxruntime_session_builder->CreateSessionOptions(&ort_options)); + auto session_options = UniqueOrtSessionOptions(ort_options, ort_api->ReleaseSessionOptions); + + if (batch_size_override_.has_value()) { + constexpr const char* DATA_BATCH = "DATA_BATCH"; + RETURN_HR_IF_NOT_OK_MSG(ort_api->AddFreeDimensionOverride(session_options.get(), DATA_BATCH, batch_size_override_.value()), + ort_api); + } + + OrtSession* ort_session = nullptr; + onnxruntime_session_builder->CreateSession(session_options.get(), &ort_session); + auto session = UniqueOrtSession(ort_session, ort_api->ReleaseSession); + + Microsoft::WRL::ComPtr onnxruntime_engine; + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine, + engine_factory_.Get(), std::move(session), onnxruntime_session_builder.Get())); + RETURN_IF_FAILED(onnxruntime_engine.CopyTo(out)); + return S_OK; +} + 
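+// Illustrative sketch: a hypothetical caller (the helper name is an assumption,
+// not WinML API) driving the builder above to produce a CPU engine with a
+// pinned batch size, assuming these methods are exposed on IEngineBuilder as
+// declared in OnnxruntimeEngineBuilder.h. DML callers would additionally call
+// SetD3D12Resources(device, queue) before CreateEngine.
+static HRESULT CreateCpuEngineSketch(Windows::AI::MachineLearning::IEngineFactory* factory,
+                                     Windows::AI::MachineLearning::IEngine** out) {
+  Microsoft::WRL::ComPtr<Windows::AI::MachineLearning::IEngineBuilder> builder;
+  RETURN_IF_FAILED(factory->CreateEngineBuilder(&builder));
+  RETURN_IF_FAILED(builder->SetBatchSizeOverride(1));  // pins the DATA_BATCH free dimension
+  return builder->CreateEngine(out);
+}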
+STDMETHODIMP OnnxruntimeEngineBuilder::GetD3D12Device(ID3D12Device** device) { + *device = device_.Get(); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::SetD3D12Resources(ID3D12Device* device, ID3D12CommandQueue* queue) { + device_ = device; + queue_ = queue; + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::GetID3D12CommandQueue(ID3D12CommandQueue** queue) { + *queue = queue_.Get(); + return S_OK; +} + +STDMETHODIMP OnnxruntimeEngineBuilder::SetBatchSizeOverride(uint32_t batch_size_override) { + batch_size_override_ = batch_size_override; + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h new file mode 100644 index 0000000000000..34e68ae742ba0 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h @@ -0,0 +1,33 @@ +#include "iengine.h" + +namespace Windows::AI::MachineLearning { + +class OnnxruntimeEngineBuilder : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IEngineBuilder> { + public: + HRESULT RuntimeClassInitialize(_In_ OnnxruntimeEngineFactory* engine); + + STDMETHOD(SetD3D12Resources) + (ID3D12Device* device, ID3D12CommandQueue* queue); + + STDMETHOD(GetD3D12Device) + (_Outptr_ ID3D12Device** device); + + STDMETHOD(GetID3D12CommandQueue) + (_Outptr_ ID3D12CommandQueue** queue); + + STDMETHOD(SetBatchSizeOverride) + (uint32_t batch_size_override); + + STDMETHOD(CreateEngine) + (_Outptr_ IEngine** out); + + private: + Microsoft::WRL::ComPtr engine_factory_; + Microsoft::WRL::ComPtr device_ = nullptr; + Microsoft::WRL::ComPtr queue_ = nullptr; + std::optional batch_size_override_; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp new file mode 100644 index 0000000000000..fbd5003b6d007 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "pch.h" +#include "OnnxruntimeEnvironment.h" +#include "OnnxruntimeErrors.h" +#include "core/platform/windows/TraceLoggingConfig.h" +#include + +using namespace Windows::AI ::MachineLearning; + +static bool debug_output_ = false; + +static void WinmlOrtLoggingCallback(void* param, OrtLoggingLevel severity, const char* category, + const char* logger_id, const char* code_location, const char* message) { + UNREFERENCED_PARAMETER(param); + UNREFERENCED_PARAMETER(logger_id); + // ORT Fatal and Error Messages are logged as Telemetry, rest are non-telemetry. 
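+  // Mapping used below: FATAL -> CRITICAL and ERROR -> ERROR (both tagged with
+  // MICROSOFT_KEYWORD_MEASURES for telemetry); WARNING and INFO map to their
+  // WINEVENT equivalents; VERBOSE and any unknown severity fall through to VERBOSE.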
+ switch (severity) { + case OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL: //Telemetry + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_CRITICAL), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR: //Telemetry + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING: + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_WARNING), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO: + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location)); + break; + case OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE: + __fallthrough; //Default is Verbose too. 
+ default: + TraceLoggingWrite( + winml_trace_logging_provider, + "WinMLLogSink", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(category), + TraceLoggingUInt32((UINT32)severity), + TraceLoggingString(message), + TraceLoggingString(code_location)); + } + + if (debug_output_) { + OutputDebugStringA((std::string(message) + "\r\n").c_str()); + } +} + +static void WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) { + if (profiler_record->category_ == OrtProfilerEventCategory::NODE_EVENT) { + TraceLoggingWrite( + winml_trace_logging_provider, + "OnnxRuntimeProfiling", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(profiler_record->category_name_, "Category"), + TraceLoggingInt64(profiler_record->duration_, "Duration (us)"), + TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"), + TraceLoggingString(profiler_record->event_name_, "Event Name"), + TraceLoggingInt32(profiler_record->process_id_, "Process ID"), + TraceLoggingInt32(profiler_record->thread_id_, "Thread ID"), + TraceLoggingString(profiler_record->op_name_, "Operator Name"), + TraceLoggingString(profiler_record->execution_provider_, "Execution Provider")); + } else { + TraceLoggingWrite( + winml_trace_logging_provider, + "OnnxRuntimeProfiling", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingOpcode(EVENT_TRACE_TYPE_INFO), + TraceLoggingString(profiler_record->category_name_, "Category"), + TraceLoggingInt64(profiler_record->duration_, "Duration (us)"), + TraceLoggingInt64(profiler_record->time_span_, "Time Stamp (us)"), + TraceLoggingString(profiler_record->event_name_, "Event Name"), + TraceLoggingInt32(profiler_record->process_id_, "Process ID"), + TraceLoggingInt32(profiler_record->thread_id_, "Thread ID")); + } +} + +OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_(nullptr, nullptr) { + OrtEnv* ort_env = nullptr; + THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), + ort_api); + ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv); + + // Configure the environment with the winml logger + auto winml_adapter_api = OrtGetWinMLAdapter(ort_api); + THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env_.get(), + &WinmlOrtLoggingCallback, &WinmlOrtProfileEventCallback, nullptr, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env), + ort_api); + + THROW_IF_NOT_OK_MSG(winml_adapter_api->OverrideSchema(), ort_api); +} + +HRESULT OnnxruntimeEnvironment::GetOrtEnvironment(_Out_ OrtEnv** ort_env) { + *ort_env = ort_env_.get(); + return S_OK; +} + +HRESULT OnnxruntimeEnvironment::EnableDebugOutput(bool is_enabled) { + debug_output_ = is_enabled; + return S_OK; +} \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.h b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h new file mode 100644 index 0000000000000..c0e01f1989b99 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
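+// OnnxruntimeEnvironment wraps the process-wide OrtEnv; its constructor routes
+// ORT logging and profiling through the WinML TraceLogging callbacks (see
+// OnnxruntimeEnvironment.cpp) and applies the WinML schema overrides.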
+ +#pragma once + +#pragma warning(push) +#pragma warning(disable : 4505) + +namespace Windows::AI ::MachineLearning { + +using UniqueOrtEnv = std::unique_ptr; + +class OnnxruntimeEnvironment { + public: + OnnxruntimeEnvironment(const OrtApi* ort_api); + + HRESULT GetOrtEnvironment(_Out_ OrtEnv** ert_env); + HRESULT EnableDebugOutput(bool is_enabled); + + private: + UniqueOrtEnv ort_env_; +}; + +} // namespace Windows::AI::MachineLearning + +#pragma warning(pop) \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeErrors.h b/winml/lib/Api.Ort/OnnxruntimeErrors.h new file mode 100644 index 0000000000000..3f9fd88b783d6 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeErrors.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once +#include "pch.h" +#include "core/providers/winml/winml_provider_factory.h" + +#ifdef _WIN32 +inline HRESULT OrtErrorCodeToHRESULT(OrtErrorCode status) noexcept { + switch (status) { + case OrtErrorCode::ORT_OK: + return S_OK; + case OrtErrorCode::ORT_FAIL: + return E_FAIL; + case OrtErrorCode::ORT_INVALID_ARGUMENT: + return E_INVALIDARG; + case OrtErrorCode::ORT_NO_SUCHFILE: + return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case OrtErrorCode::ORT_NO_MODEL: + return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case OrtErrorCode::ORT_ENGINE_ERROR: + return E_FAIL; + case OrtErrorCode::ORT_RUNTIME_EXCEPTION: + return E_FAIL; + case OrtErrorCode::ORT_INVALID_PROTOBUF: + return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case OrtErrorCode::ORT_MODEL_LOADED: + return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case OrtErrorCode::ORT_NOT_IMPLEMENTED: + return E_NOTIMPL; + case OrtErrorCode::ORT_INVALID_GRAPH: + return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case OrtErrorCode::ORT_EP_FAIL: + return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + default: + return E_FAIL; + } +} +#endif + +#define RETURN_HR_IF_NOT_OK_MSG(status, ort_api) \ + do { \ + auto _status = status; \ + if (_status) { \ + auto error_code = ort_api->GetErrorCode(_status); \ + auto error_message = ort_api->GetErrorMessage(_status); \ + HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \ + telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \ + RETURN_HR_MSG(hresult, \ + error_message); \ + } \ + } while (0) + +#define THROW_IF_NOT_OK_MSG(status, ort_api) \ + do { \ + auto _status = status; \ + if (_status) { \ + auto error_code = ort_api->GetErrorCode(_status); \ + auto error_message = ort_api->GetErrorMessage(_status); \ + HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \ + telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \ + winrt::hstring errorMessage(WinML::Strings::HStringFromUTF8(error_message)); \ + throw winrt::hresult_error(hresult, errorMessage); \ + } \ + } while (0) diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.cpp b/winml/lib/Api.Ort/OnnxruntimeModel.cpp new file mode 100644 index 0000000000000..bc782fbd17343 --- /dev/null +++ b/winml/lib/Api.Ort/OnnxruntimeModel.cpp @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include "pch.h" +#include "OnnxruntimeModel.h" +#include "core/platform/windows/TraceLoggingConfig.h" +#include + +#include "OnnxruntimeDescriptorConverter.h" +#include "OnnxruntimeEngine.h" +#include "OnnxruntimeErrors.h" + +using namespace Windows::AI::MachineLearning; + +struct winml_adapter_api_model_feature_helper { + decltype(WinmlAdapterApi::ModelGetInputCount) GetCount; + decltype(WinmlAdapterApi::ModelGetInputName) GetName; + decltype(WinmlAdapterApi::ModelGetInputDescription) GetDescription; + decltype(WinmlAdapterApi::ModelGetInputTypeInfo) GetTypeInfo; +}; + +HRESULT CreateFeatureDescriptors( + OnnxruntimeEngineFactory* engine_factory, + const winml_adapter_api_model_feature_helper* feature_helpers, + OrtModel* ort_model, + std::vector& descriptors) { + const auto ort_api = engine_factory->UseOrtApi(); + size_t count; + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetCount(ort_model, &count), + engine_factory->UseOrtApi()); + + for (size_t i = 0; i < count; i++) { + OnnxruntimeValueInfoWrapper descriptor; + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetName(ort_model, i, &descriptor.name_, &descriptor.name_length_), + engine_factory->UseOrtApi()); + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetDescription(ort_model, i, &descriptor.description_, &descriptor.description_length_), + engine_factory->UseOrtApi()); + + OrtTypeInfo* type_info; + RETURN_HR_IF_NOT_OK_MSG(feature_helpers->GetTypeInfo(ort_model, i, &type_info), + engine_factory->UseOrtApi()); + + descriptor.type_info_ = UniqueOrtTypeInfo(type_info, ort_api->ReleaseTypeInfo); + + descriptors.push_back(std::move(descriptor)); + } + return S_OK; +} + +HRESULT ModelInfo::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, OrtModel* ort_model) { + RETURN_HR_IF_NULL(E_INVALIDARG, ort_model); + + const auto winml_adapter_api = engine_factory->UseWinmlAdapterApi(); + + // Get Metadata + size_t count; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetMetadataCount(ort_model, &count), + engine_factory->UseOrtApi()); + + const char* metadata_key; + size_t metadata_key_len; + const char* metadata_value; + size_t metadata_value_len; + for (size_t i = 0; i < count; i++) { + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetMetadata(ort_model, i, &metadata_key, &metadata_key_len, &metadata_value, &metadata_value_len), + engine_factory->UseOrtApi()); + + model_metadata_.insert_or_assign( + std::string(metadata_key, metadata_key_len), + std::string(metadata_value, metadata_value_len)); + } + + WinML::OnnxruntimeDescriptorConverter converter(engine_factory, model_metadata_); + + static const winml_adapter_api_model_feature_helper input_helpers = { + winml_adapter_api->ModelGetInputCount, + winml_adapter_api->ModelGetInputName, + winml_adapter_api->ModelGetInputDescription, + winml_adapter_api->ModelGetInputTypeInfo}; + + // Create inputs + std::vector inputs; + RETURN_IF_FAILED(CreateFeatureDescriptors(engine_factory, &input_helpers, ort_model, inputs)); + input_features_ = converter.ConvertToLearningModelDescriptors(inputs); + + // Create outputs + static const winml_adapter_api_model_feature_helper output_helpers = { + winml_adapter_api->ModelGetOutputCount, + winml_adapter_api->ModelGetOutputName, + winml_adapter_api->ModelGetOutputDescription, + winml_adapter_api->ModelGetOutputTypeInfo}; + + std::vector outputs; + RETURN_IF_FAILED(CreateFeatureDescriptors(engine_factory, &output_helpers, ort_model, outputs)); + output_features_ = converter.ConvertToLearningModelDescriptors(outputs); + + const char* out; + size_t len; + + 
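+  // The ModelGet* accessors below return (pointer, length) pairs that view
+  // memory owned by the OrtModel; copying into std::string keeps the values
+  // valid after the model is released.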
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetAuthor(ort_model, &out, &len), + engine_factory->UseOrtApi()); + author_ = std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetName(ort_model, &out, &len), + engine_factory->UseOrtApi()); + name_ = std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetDomain(ort_model, &out, &len), + engine_factory->UseOrtApi()); + domain_ = std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetDescription(ort_model, &out, &len), + engine_factory->UseOrtApi()); + description_ = std::string(out, len); + + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelGetVersion(ort_model, &version_), + engine_factory->UseOrtApi()); + + return S_OK; +} + +STDMETHODIMP ModelInfo::GetAuthor(const char** out, size_t* len) { + *out = author_.c_str(); + *len = author_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetName(const char** out, size_t* len) { + *out = name_.c_str(); + *len = name_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetDomain(const char** out, size_t* len) { + *out = domain_.c_str(); + *len = domain_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetDescription(const char** out, size_t* len) { + *out = description_.c_str(); + *len = description_.size(); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetVersion(int64_t* out) { + *out = version_; + return S_OK; +} + +STDMETHODIMP ModelInfo::GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView** metadata) { + std::unordered_map map_copy; + for (auto& pair : model_metadata_) { + auto metadata_key = WinML::Strings::HStringFromUTF8(pair.first); + auto metadata_value = WinML::Strings::HStringFromUTF8(pair.second); + map_copy.emplace(std::move(metadata_key), std::move(metadata_value)); + } + auto map = winrt::single_threaded_map(std::move(map_copy)); + winrt::copy_to_abi(map, *(void**)metadata); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetInputFeatures(ABI::Windows::Foundation::Collections::IVectorView** features) { + *features = nullptr; + winrt::copy_to_abi(input_features_.GetView(), *(void**)features); + return S_OK; +} + +STDMETHODIMP ModelInfo::GetOutputFeatures(ABI::Windows::Foundation::Collections::IVectorView** features) { + *features = nullptr; + winrt::copy_to_abi(output_features_.GetView(), *(void**)features); + return S_OK; +} + +OnnruntimeModel::OnnruntimeModel() : ort_model_(nullptr, nullptr) { +} + +STDMETHODIMP OnnruntimeModel::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, UniqueOrtModel&& ort_model) { + RETURN_HR_IF_NULL(E_INVALIDARG, ort_model); + + engine_factory_ = engine_factory; + ort_model_ = std::move(ort_model); + + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::GetModelInfo(IModelInfo** info) { + if (info_ == nullptr) { + RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&info_, engine_factory_.Get(), ort_model_.get())); + } + + info_.CopyTo(info); + + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::ModelEnsureNoFloat16() { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->ModelEnsureNoFloat16(ort_model_.get()), + engine_factory_->UseOrtApi()); + return S_OK; +} + +STDMETHODIMP OnnruntimeModel::CloneModel(IModel** copy) { + auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi(); + + OrtModel* ort_model_copy; + RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->CloneModel(ort_model_.get(), &ort_model_copy), + engine_factory_->UseOrtApi()); + + auto model = UniqueOrtModel(ort_model_copy, 
winml_adapter_api->ReleaseModel);
+  RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnruntimeModel>(copy, engine_factory_.Get(), std::move(model)));
+
+  return S_OK;
+}
+
+STDMETHODIMP OnnruntimeModel::DetachOrtModel(OrtModel** model) {
+  *model = ort_model_.release();
+  return S_OK;
+}
diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.h b/winml/lib/Api.Ort/OnnxruntimeModel.h
new file mode 100644
index 0000000000000..1be587cfc8b48
--- /dev/null
+++ b/winml/lib/Api.Ort/OnnxruntimeModel.h
@@ -0,0 +1,80 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "iengine.h"
+
+namespace Windows::AI::MachineLearning {
+
+class OnnxruntimeEngineFactory;
+
+// IOnnxruntimeModel exposes the underlying OrtModel so that the engine can
+// detach it and take ownership when a session is created.
+MIDL_INTERFACE("92679cbf-7a9d-48bb-b97f-ef9fb447ce8e")
+IOnnxruntimeModel : IUnknown {
+  virtual HRESULT STDMETHODCALLTYPE DetachOrtModel(OrtModel** model) PURE;
+};
+
+class ModelInfo : public Microsoft::WRL::RuntimeClass<
+                      Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
+                      IModelInfo> {
+ public:
+  HRESULT RuntimeClassInitialize(_In_ OnnxruntimeEngineFactory* engine, _In_ OrtModel* ort_model);
+
+  STDMETHOD(GetAuthor)
+  (const char** out, size_t* len);
+  STDMETHOD(GetName)
+  (const char** out, size_t* len);
+  STDMETHOD(GetDomain)
+  (const char** out, size_t* len);
+  STDMETHOD(GetDescription)
+  (const char** out, size_t* len);
+  STDMETHOD(GetVersion)
+  (int64_t* out);
+  STDMETHOD(GetModelMetadata)
+  (ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING>** metadata);
+  STDMETHOD(GetInputFeatures)
+  (ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features);
+  STDMETHOD(GetOutputFeatures)
+  (ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor>** features);
+
+ private:
+  std::string author_;
+  std::string name_;
+  std::string domain_;
+  std::string description_;
+  int64_t version_;
+  std::unordered_map<std::string, std::string> model_metadata_;
+  wfc::IVector<winml::ILearningModelFeatureDescriptor> input_features_;
+  wfc::IVector<winml::ILearningModelFeatureDescriptor> output_features_;
+};
+
+class OnnruntimeModel : public Microsoft::WRL::RuntimeClass<
+                            Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
+                            IModel,
+                            IOnnxruntimeModel> {
+ public:
+  OnnruntimeModel();
+
+  HRESULT RuntimeClassInitialize(OnnxruntimeEngineFactory* engine, UniqueOrtModel&& ort_model);
+
+  STDMETHOD(GetModelInfo)
+  (IModelInfo** info);
+  STDMETHOD(ModelEnsureNoFloat16)
+  ();
+  STDMETHOD(CloneModel)
+  (IModel** copy);
+  STDMETHOD(DetachOrtModel)
+  (OrtModel** model);
+
+ private:
+  UniqueOrtModel ort_model_;
+
+  Microsoft::WRL::ComPtr<OnnxruntimeEngineFactory> engine_factory_;
+  Microsoft::WRL::ComPtr<ModelInfo> info_;
+
+  std::optional<std::unordered_map<std::string, std::string>> metadata_cache_;
+};
+
+} // namespace Windows::AI::MachineLearning
\ No newline at end of file
diff --git a/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h
new file mode 100644
index 0000000000000..372da3c792c9f
--- /dev/null
+++ b/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace Windows::AI::MachineLearning {
+
+// The IOrtSessionBuilder offers an abstraction over the creation of an
+// InferenceSession, enabling the session to be created based on a device (CPU/DML).
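+// OnnxruntimeEngineBuilder::CreateEngine selects the concrete implementation:
+// OnnxruntimeCpuSessionBuilder when no D3D12 device has been set, otherwise
+// OnnxruntimeDmlSessionBuilder (USE_DML builds) with the supplied device and queue.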
+MIDL_INTERFACE("2746f03a-7e08-4564-b5d0-c670fef116ee") +IOrtSessionBuilder : IUnknown { + virtual HRESULT STDMETHODCALLTYPE CreateSessionOptions( + OrtSessionOptions * *options) = 0; + + virtual HRESULT STDMETHODCALLTYPE CreateSession( + OrtSessionOptions * options, + OrtSession * *session) = 0; + + virtual HRESULT STDMETHODCALLTYPE Initialize( + OrtSession * session) = 0; +}; + +} // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h b/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h new file mode 100644 index 0000000000000..120965f4a7e80 --- /dev/null +++ b/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "iengine.h" + +STDAPI CreateOnnxruntimeEngineFactory(_Out_ Windows::AI::MachineLearning::IEngineFactory** engine_factory); \ No newline at end of file diff --git a/winml/lib/Api.Ort/pch.h b/winml/lib/Api.Ort/pch.h new file mode 100644 index 0000000000000..e41ad60623e9b --- /dev/null +++ b/winml/lib/Api.Ort/pch.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include "winrt_headers.h" + +#include "core/providers/winml/winml_provider_factory.h" +#include "adapter/winml_adapter_c_api.h" + +using UniqueOrtModel = std::unique_ptr; +using UniqueOrtSessionOptions = std::unique_ptr; +using UniqueOrtSession = std::unique_ptr; +using UniqueOrtExecutionProvider = std::unique_ptr; +using UniqueOrtValue = std::unique_ptr; +using UniqueOrtMemoryInfo = std::unique_ptr; +using UniqueOrtTypeInfo = std::unique_ptr; +using UniqueOrtTensorTypeAndShapeInfo = std::unique_ptr; +using UniqueOrtAllocator = std::unique_ptr; +using UniqueOrtRunOptions = std::unique_ptr; diff --git a/winml/lib/Api/FeatureValues.h b/winml/lib/Api/FeatureValues.h index b621238768b84..637a4a2c61a74 100644 --- a/winml/lib/Api/FeatureValues.h +++ b/winml/lib/Api/FeatureValues.h @@ -58,7 +58,7 @@ \ type(std::vector const& shape) : Base(shape){}; \ \ - type(std::vector const& shape, ID3D12Resource* pResource, UINT64 resource_width) : Base(shape, pResource, resource_width){}; \ + type(std::vector const& shape, ID3D12Resource* pResource) : Base(shape, pResource){}; \ }; \ } \ namespace winrt::Windows::AI::MachineLearning::factory_implementation { \ @@ -85,7 +85,7 @@ CREATE_TENSOR(TensorUInt32Bit, uint32_t, uint32_t) CREATE_TENSOR(TensorInt32Bit, int32_t, int32_t) CREATE_TENSOR(TensorUInt64Bit, uint64_t, uint64_t) CREATE_TENSOR(TensorInt64Bit, int64_t, int64_t) -CREATE_TENSOR(TensorFloat16Bit, onnxruntime::MLFloat16, float) +CREATE_TENSOR(TensorFloat16Bit, WinML::Half, float) #pragma warning(push) #pragma warning(disable : 4702) // Unreachable code (one of TensorBase's constructor unconditionally throws for diff --git a/winml/lib/Api/ImageFeatureDescriptor.cpp b/winml/lib/Api/ImageFeatureDescriptor.cpp index ac53db7d6e7df..e2f1e70000512 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.cpp +++ b/winml/lib/Api/ImageFeatureDescriptor.cpp @@ -11,9 +11,9 @@ namespace winrt::Windows::AI::MachineLearning::implementation { ImageFeatureDescriptor::ImageFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, wgi::BitmapPixelFormat pixel_format, wgi::BitmapAlphaMode alpha_mode, uint32_t width, @@ -32,28 +32,6 @@ ImageFeatureDescriptor::ImageFeatureDescriptor( color_space_gamma_(color_space_gamma) { } 
-ImageFeatureDescriptor::ImageFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& TensorKind, - array_view Shape, - Windows::Graphics::Imaging::BitmapPixelFormat const& BitmapPixelFormat, - Windows::Graphics::Imaging::BitmapAlphaMode const& BitmapAlphaMode, - uint32_t Width, - uint32_t Height) : name_(Name), - description_(Description), - tensor_kind_(TensorKind), - shape_(Shape.begin(), Shape.end()), - is_required_(IsRequired), - pixel_format_(BitmapPixelFormat), - alpha_mode_(BitmapAlphaMode), - width_(Width), - height_(Height), - nominal_pixel_range_(ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_0_255), - color_space_gamma_(ImageColorSpaceGamma::ImageColorSpaceGamma_SRGB) { -} - wgi::BitmapPixelFormat ImageFeatureDescriptor::BitmapPixelFormat() try { return pixel_format_; diff --git a/winml/lib/Api/ImageFeatureDescriptor.h b/winml/lib/Api/ImageFeatureDescriptor.h index 336e4230dd9b4..54f1f265b3724 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.h +++ b/winml/lib/Api/ImageFeatureDescriptor.h @@ -24,9 +24,9 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, wgi::BitmapPixelFormat pixelformat, wgi::BitmapAlphaMode alphamode, uint32_t width, @@ -34,17 +34,6 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageNominalPixelRange nominalPixelRange, ImageColorSpaceGamma colorSpaceGamma); - ImageFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& TensorKind, - array_view Shape, - Windows::Graphics::Imaging::BitmapPixelFormat const& BitmapPixelFormat, - Windows::Graphics::Imaging::BitmapAlphaMode const& BitmapAlphaMode, - uint32_t Width, - uint32_t Height); - wgi::BitmapPixelFormat BitmapPixelFormat(); @@ -104,10 +93,4 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageNominalPixelRange nominal_pixel_range_; ImageColorSpaceGamma color_space_gamma_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - -namespace winrt::Windows::AI::MachineLearning::factory_implementation { - struct ImageFeatureDescriptor : ImageFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/ImageFeatureValue.cpp b/winml/lib/Api/ImageFeatureValue.cpp index f24d41489130f..983ee677471de 100644 --- a/winml/lib/Api/ImageFeatureValue.cpp +++ b/winml/lib/Api/ImageFeatureValue.cpp @@ -20,10 +20,6 @@ #include "D3DDeviceCache.h" #include "TensorFeatureDescriptor.h" -// Uncomment to enable DEBUG_IMAGE_TENSOR_RESOURCE and -// allow debugging the content of the resource -//#define DEBUG_IMAGE_TENSOR_RESOURCE - using namespace WinML; using namespace winrt::Windows::Graphics::Imaging; using namespace winrt::Windows::Graphics::DirectX::Direct3D11; @@ -38,87 +34,6 @@ struct ImageFeatureValue::ImageResourceMetadata { ::Windows::AI::MachineLearning::Internal::ImageTensorDescription TensorDescriptor; }; -#ifdef ENABLE_IMAGE_FEATURE_VALUE_TENSOR_DUMP -static void DumpResourceToCPU( - ID3D12Resource* pResource, - com_ptr spSession, - ImageTensorDescription tensorDescriptor, - ::Windows::AI::MachineLearning::Internal::TensorToVideoFrameConverter* tensorToImageConverter) { 
- auto spDevice = spSession->Device().as(); - auto spD3DDevice = spDevice->GetD3DDevice(); - auto spCommandQueue = spDevice->GetDeviceQueue(); - auto pProvider = spSession->GetExecutionProvider(); - - UINT64 bufferbytesize = pResource->GetDesc().Width; - - Dml::FlushContext(pProvider); - - D3D12_HEAP_PROPERTIES heapProperties = { - D3D12_HEAP_TYPE_READBACK, - D3D12_CPU_PAGE_PROPERTY_UNKNOWN, - D3D12_MEMORY_POOL_UNKNOWN, - 0, - 0}; - D3D12_RESOURCE_DESC resourceDesc = { - D3D12_RESOURCE_DIMENSION_BUFFER, - 0, - bufferbytesize, - 1, - 1, - 1, - DXGI_FORMAT_UNKNOWN, - {1, 0}, - D3D12_TEXTURE_LAYOUT_ROW_MAJOR, - D3D12_RESOURCE_FLAG_NONE}; - - ID3D12Resource* pCPUResource = nullptr; - spD3DDevice->CreateCommittedResource( - &heapProperties, - D3D12_HEAP_FLAG_NONE, - &resourceDesc, - D3D12_RESOURCE_STATE_COPY_DEST, - nullptr, - IID_PPV_ARGS(&pCPUResource)); - - { - ScopedCommandList scopedCommandList(spSession); - // Record command list copy action - scopedCommandList.get()->CopyResource(pCPUResource, pResource); - scopedCommandList.get()->Close(); - ID3D12CommandList* pCommandLists[] = {scopedCommandList.get()}; - spCommandQueue->ExecuteCommandLists(ARRAYSIZE(pCommandLists), pCommandLists); - - // TODO: Do we need to set a fence here and wait for completion before - // reading the resource in cpu memory? - } - - D3D12_RANGE range = {0, static_cast(bufferbytesize)}; - - void* pData = nullptr; - pCPUResource->Map(0, &range, reinterpret_cast(&pData)); - - range.End = 0; - - DebugBreak(); - - SoftwareBitmap bitmap(BitmapPixelFormat::Bgra8, 720, 720); - Windows::Media::VideoFrame frame = Windows::Media::VideoFrame::CreateWithSoftwareBitmap(bitmap); - tensorToImageConverter->SoftwareTensorToVideoFrame( - spSession.as(), - reinterpret_cast(pData), - tensorDescriptor, - frame); - - auto folder = Windows::Storage::StorageFolder::GetFolderFromPathAsync(L"C:\\").get(); - auto imagefile = folder.CreateFileAsync(L"out.png", Windows::Storage::CreationCollisionOption::ReplaceExisting).get(); - auto stream = imagefile.OpenAsync(Windows::Storage::FileAccessMode::ReadWrite).get(); - auto encoder = BitmapEncoder::CreateAsync(BitmapEncoder::JpegEncoderId(), stream).get(); - encoder.SetSoftwareBitmap(frame.SoftwareBitmap()); - encoder.FlushAsync(); - pResource->Unmap(0, &range); -} -#endif - Windows::AI::MachineLearning::ImageFeatureValue ImageFeatureValue::Create( uint32_t batchSize, BitmapPixelFormat format, @@ -329,13 +244,12 @@ static void CPUTensorize( std::vector bounds, ImageTensorDescription tensorDescriptor, com_ptr spSession, - void* pResource, + BYTE* resource, unsigned int singleFrameBufferSize) { // Tensorize video frames one by one without extra copy. 
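+  // Each iteration below tensorizes frame 'batchIdx' directly into its slice of
+  // the shared buffer, advancing 'resource' by singleFrameBufferSize per frame.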
- BYTE* tempPResource = reinterpret_cast(pResource); for (uint32_t batchIdx = 0; batchIdx < videoFrames.Size(); ++batchIdx) { - CPUTensorize(videoFrames.GetAt(batchIdx), bounds[batchIdx], tensorDescriptor, spSession, tempPResource); - tempPResource += singleFrameBufferSize; + CPUTensorize(videoFrames.GetAt(batchIdx), bounds[batchIdx], tensorDescriptor, spSession, resource); + resource += singleFrameBufferSize; } } @@ -344,15 +258,8 @@ static void GPUTensorize( std::vector bounds, ImageTensorDescription tensorDescriptor, com_ptr spSession, - void* pAllocatedResource, + ID3D12Resource* d3dResource, WinML::BindingContext& context) { - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - - auto d3dResource = - adapter->GetD3D12ResourceFromAllocation( - spSession->GetExecutionProvider(), - pAllocatedResource); auto spDevice = spSession->Device().as(); ConverterResourceDescription descriptor = {}; @@ -386,9 +293,6 @@ static void GPUTensorize( context.converter = pooledConverter; } } -#ifdef DEBUG_IMAGE_TENSOR_RESOURCE - DumpResourceToCPU(d3dResource, spSession, tensorDescriptor); -#endif } std::optional ImageFeatureValue::GetInputMetadata(const WinML::BindingContext& context) { @@ -490,7 +394,7 @@ std::optional ImageFeatureValue::GetIn return ImageResourceMetadata{bounds, imageTensorDescriptor}; } -HRESULT ImageFeatureValue::GetOrtValue(WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator) try { +HRESULT ImageFeatureValue::GetValue(WinML::BindingContext& context, IValue** out) try { FAIL_FAST_IF(!(std::all_of(m_widths.begin(), m_widths.end(), [](int i) { return i != 0; }))); FAIL_FAST_IF(!(std::all_of(m_heights.begin(), m_heights.end(), [](int i) { return i != 0; }))); @@ -502,27 +406,19 @@ HRESULT ImageFeatureValue::GetOrtValue(WinML::BindingContext& context, OrtValue* // Get the session auto spSession = context.session.as(); auto spDevice = spSession->Device().as(); - auto provider = spSession->GetExecutionProvider(); - - // and the adapter - if (!m_adapter) { - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(m_adapter.put())); - } + auto engine = spSession->GetEngine(); // create the OrtValue - Ort::Allocator dml_allocator(m_adapter.get(), nullptr); - WINML_THROW_IF_FAILED(m_adapter->GetProviderAllocator(provider, dml_allocator.put())); - - // create the OrtValue as a tensor letting ort know that we own the data buffer - Ort::Value ort_tensor = Ort::Value::CreateTensor( - dml_allocator, - &(resourceMetadata.TensorDescriptor.sizes[0]), + winrt::com_ptr value; + RETURN_IF_FAILED(engine->CreateTensorValue( + resourceMetadata.TensorDescriptor.sizes, sizeof(resourceMetadata.TensorDescriptor.sizes) / sizeof(resourceMetadata.TensorDescriptor.sizes[0]), - (resourceMetadata.TensorDescriptor.dataType == kImageTensorDataTypeFloat32) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16); + resourceMetadata.TensorDescriptor.dataType == kImageTensorDataTypeFloat32 ? 
winml::TensorKind::Float : winml::TensorKind::Float16, + value.put())); // Get the tensor raw data - void* pAllocatedResource = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(ort_tensor, &pAllocatedResource)); + WinML::Resource void_resource; + RETURN_IF_FAILED(value->GetResource(void_resource)); if (context.type == BindingType::kInput) { // Only tensorize inputs @@ -530,15 +426,15 @@ HRESULT ImageFeatureValue::GetOrtValue(WinML::BindingContext& context, OrtValue* auto bufferByteSize = GetSizeFromTensorDataType(resourceMetadata.TensorDescriptor.dataType) * bufferSize; auto singleFrameBufferSize = bufferByteSize / m_batchSize; if (spDevice->IsCpuDevice()) { - CPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, pAllocatedResource, static_cast(singleFrameBufferSize)); - } - else { - GPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, pAllocatedResource, context); + auto resource = reinterpret_cast(void_resource.get()); + CPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, resource, static_cast(singleFrameBufferSize)); + } else { + auto resource = reinterpret_cast(void_resource.get()); + GPUTensorize(m_videoFrames, resourceMetadata.Bounds, resourceMetadata.TensorDescriptor, spSession, resource, context); } } - *ort_value = ort_tensor.release(); - *ort_allocator = dml_allocator.release(); + *out = value.detach(); return S_OK; } WINML_CATCH_ALL_COM @@ -549,18 +445,14 @@ HRESULT ImageFeatureValue::IsPlaceholder(bool* pIsPlaceHolder) { return S_OK; } -HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, OrtValue* ort_value) try { +HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, IValue* value) try { // Get the device auto spSession = context.session.as(); auto spDevice = spSession->Device().as(); - if (!m_adapter) { - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(m_adapter.put())); - } - // Get the output tensor raw data - void* pAllocatedResource = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(ort_value, &pAllocatedResource)); + WinML::Resource void_resource; + RETURN_IF_FAILED(value->GetResource(void_resource)); // Get the run context auto metadata = GetInputMetadata(context); @@ -570,36 +462,30 @@ HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, Ort descriptor.width = static_cast(resourceMetadata.TensorDescriptor.sizes[3]); descriptor.height = static_cast(resourceMetadata.TensorDescriptor.sizes[2]); - Ort::MemoryInfo memory_info(nullptr); - m_adapter->GetValueMemoryInfo(ort_value, memory_info.put()); - - if (!strcmp(memory_info.Name(), onnxruntime::CPU) || - memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUOutput || - memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUInput) { + bool out; + if (SUCCEEDED(value->IsCpu(&out)) && out) { descriptor.pixel_format = static_cast(BitmapPixelFormat::Bgra8); descriptor.luid = {}; // Converted image on CPU auto pooledConverter = PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); - auto bufferSize = std::accumulate(std::begin(resourceMetadata.TensorDescriptor.sizes), std::end(resourceMetadata.TensorDescriptor.sizes), static_cast< int64_t>(1), std::multiplies()); + auto bufferSize = std::accumulate(std::begin(resourceMetadata.TensorDescriptor.sizes), std::end(resourceMetadata.TensorDescriptor.sizes), static_cast(1), std::multiplies()); auto bufferByteSize = 
GetSizeFromTensorDataType(resourceMetadata.TensorDescriptor.dataType) * bufferSize / m_batchSize; - BYTE* tempPAllocatedResource = reinterpret_cast(pAllocatedResource); + BYTE* resource = reinterpret_cast(void_resource.get()); for (uint32_t batchIdx = 0; batchIdx < m_batchSize; ++batchIdx) { // Convert Software Tensor to VideoFrame one by one based on the buffer size. auto videoFrame = m_videoFrames.GetAt(batchIdx); - pooledConverter->Get()->Detensorizer->SoftwareTensorToVideoFrame(context.session, tempPAllocatedResource, resourceMetadata.TensorDescriptor, videoFrame); - tempPAllocatedResource += bufferByteSize; + pooledConverter->Get()->Detensorizer->SoftwareTensorToVideoFrame(context.session, resource, resourceMetadata.TensorDescriptor, videoFrame); + resource += bufferByteSize; } - } - else { + } else { descriptor.pixel_format = static_cast(DirectXPixelFormat::B8G8R8X8UIntNormalized); descriptor.luid = spDevice->GetD3DDevice()->GetAdapterLuid(); // Converted image on GPU auto pooledConverter = PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); - auto pProvider = spSession->GetExecutionProvider(); - auto d3dResource = m_adapter->GetD3D12ResourceFromAllocation(pProvider, pAllocatedResource); + auto d3dResource = reinterpret_cast(void_resource.get()); for (uint32_t batchIdx = 0; batchIdx < m_batchSize; ++batchIdx) { auto videoFrame = m_videoFrames.GetAt(batchIdx); @@ -614,9 +500,6 @@ HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, Ort spDevice->GetD3DDeviceCache()->SyncD3D12ToCPU(); pooledConverter->Get()->Detensorizer->ResetAllocator(); } -#ifdef DEBUG_IMAGE_TENSOR_RESOURCE - DumpResourceToCPU(d3dResource, spSession, resourceInfo.Metadata.TensorDescriptor); -#endif } // Release any converters back to the pool by nulling out the wrapper. 
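+  // Note: IValue::IsCpu drives the branch above; the same WinML::Resource blob
+  // returned by IValue::GetResource is reinterpreted either as a BYTE* for the
+  // software detensorizer or as an ID3D12Resource* for the GPU path.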
diff --git a/winml/lib/Api/ImageFeatureValue.h b/winml/lib/Api/ImageFeatureValue.h index d826c12e231ba..c135d2fee4a3d 100644 --- a/winml/lib/Api/ImageFeatureValue.h +++ b/winml/lib/Api/ImageFeatureValue.h @@ -32,20 +32,20 @@ struct ImageFeatureValue : ImageFeatureValueT GetInputMetadata(const WinML::BindingContext& context); // ILotusValueProviderPrivate implementation - STDMETHOD(GetOrtValue) - (WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator); + STDMETHOD(GetValue) + (WinML::BindingContext& context, WinML::IValue** out); STDMETHOD(IsPlaceholder) (bool* pIsPlaceHolder); STDMETHOD(UpdateSourceResourceData) - (WinML::BindingContext& context, OrtValue* ort_value); + (WinML::BindingContext& context, WinML::IValue* value); STDMETHOD(AbiRepresentation) (winrt::Windows::Foundation::IInspectable& abiRepresentation); std::vector Widths() { return m_widths; } std::vector Heights() { return m_heights; } bool IsBatch() { return m_batchSize > 1; } + private: - com_ptr m_adapter; winrt::Windows::Foundation::Collections::IVector m_videoFrames; std::vector m_widths = {}; std::vector m_heights = {}; diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index 51c8a02e04e2e..13b685b963ad5 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -10,6 +10,10 @@ #include "SequenceFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" +#include "OnnxruntimeProvider.h" + +#include + namespace winrt::Windows::AI::MachineLearning::implementation { LearningModel::LearningModel( const hstring& path, @@ -22,70 +26,96 @@ LearningModel::LearningModel( const std::string& path, const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad); - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); - WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); - WINML_THROW_IF_FAILED(adapter_->CreateModelProto(path.c_str(), model_proto_.put())); - - Initialize(); + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); + WINML_THROW_IF_FAILED(engine_factory_->CreateModel(path.c_str(), path.size(), model_.put())); + WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } WINML_CATCH_ALL +static HRESULT CreateModelFromStream( + WinML::IEngineFactory* engine_factory, + const wss::IRandomAccessStreamReference stream, + WinML::IModel** model) { + auto content = stream.OpenReadAsync().get(); + + wss::Buffer buffer(static_cast(content.Size())); + auto result = content.ReadAsync( + buffer, + buffer.Capacity(), + wss::InputStreamOptions::None) + .get(); + + auto bytes = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>(); + WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes, "Model stream is invalid."); + + void* data; + WINML_THROW_IF_FAILED_MSG(bytes->Buffer(reinterpret_cast(&data)), "Failed to acquire buffer from model stream."); + + size_t len = static_cast(content.Size()); + WINML_THROW_IF_FAILED(engine_factory->CreateModel(data, len, model)); + + return S_OK; +} + LearningModel::LearningModel( const wss::IRandomAccessStreamReference stream, const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad); - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); - WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); - 
WINML_THROW_IF_FAILED(adapter_->CreateModelProto( - static_cast(winrt::get_abi(stream)), - model_proto_.put())); - - Initialize(); + WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); + WINML_THROW_IF_FAILED(CreateModelFromStream(engine_factory_.get(), stream, model_.put())); + WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } WINML_CATCH_ALL -void LearningModel::Initialize() { - WINML_THROW_IF_FAILED(adapter_->CreateModelInfo(model_proto_.get(), model_info_.put())); -} - hstring LearningModel::Author() try { - return WinML::Strings::HStringFromUTF8(model_info_->author()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetAuthor(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Name() try { - return WinML::Strings::HStringFromUTF8( - model_info_->name()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetName(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Domain() try { - return WinML::Strings::HStringFromUTF8( - model_info_->domain()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetDomain(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Description() try { - return WinML::Strings::HStringFromUTF8( - model_info_->description()); + const char* out; + size_t len; + WINML_THROW_IF_FAILED(model_info_->GetDescription(&out, &len)); + return WinML::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL int64_t LearningModel::Version() try { - return model_info_->version(); + int64_t version; + WINML_THROW_IF_FAILED(model_info_->GetVersion(&version)); + return version; } WINML_CATCH_ALL wfc::IMapView LearningModel::Metadata() try { - ABI::Windows::Foundation::Collections::IMapView* metadata; + ABI::Windows::Foundation::Collections::IMapView* metadata = nullptr; wfc::IMapView out; WINML_THROW_IF_FAILED(model_info_->GetModelMetadata(&metadata)); winrt::attach_abi(out, metadata); @@ -104,13 +134,14 @@ LearningModel::GetOperatorRegistry() { operator_provider_.as(); IMLOperatorRegistry* registry = nullptr; - WINML_THROW_IF_FAILED(adapter_->GetOperatorRegistry(operator_provider_native.get(), ®istry)); + // Retrieve the "operator abi" registry. 
+ THROW_IF_FAILED(operator_provider_native->GetRegistry(®istry)); return registry; } wfc::IVectorView LearningModel::InputFeatures() try { - ABI::Windows::Foundation::Collections::IVectorView* features; + ABI::Windows::Foundation::Collections::IVectorView* features = nullptr; wfc::IVectorView out; WINML_THROW_IF_FAILED(model_info_->GetInputFeatures(&features)); winrt::attach_abi(out, features); @@ -120,7 +151,7 @@ WINML_CATCH_ALL wfc::IVectorView LearningModel::OutputFeatures() try { - ABI::Windows::Foundation::Collections::IVectorView* features; + ABI::Windows::Foundation::Collections::IVectorView* features = nullptr; wfc::IVectorView out; WINML_THROW_IF_FAILED(model_info_->GetOutputFeatures(&features)); winrt::attach_abi(out, features); @@ -130,12 +161,12 @@ WINML_CATCH_ALL void LearningModel::Close() try { // close the model - model_proto_ = nullptr; + model_ = nullptr; } WINML_CATCH_ALL bool LearningModel::IsDisposed() { - return model_proto_ == nullptr; + return model_ == nullptr; } wf::IAsyncOperation @@ -196,30 +227,33 @@ LearningModel::LoadFromStream( } WINML_CATCH_ALL -winmla::IModelProto* -LearningModel::DetachModelProto() { - com_ptr detached_model_proto; - if (model_proto_ != nullptr) { - detached_model_proto.attach(model_proto_.detach()); +WinML::IModel* +LearningModel::DetachModel() { + com_ptr detached_model; + if (model_ != nullptr) { + detached_model.attach(model_.detach()); // Close the model since we now own the model proto Close(); } - return detached_model_proto.detach(); + return detached_model.detach(); } -winmla::IModelProto* -LearningModel::CopyModelProto() { - if (model_proto_ == nullptr) { +WinML::IModel* +LearningModel::CloneModel() { + if (model_ == nullptr) { return nullptr; } - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - com_ptr model_proto; - WINML_THROW_IF_FAILED(adapter->CreateModelProto(model_proto_.get(), model_proto.put())); + com_ptr model_copy; + WINML_THROW_IF_FAILED(model_->CloneModel(model_copy.put())); + + return model_copy.detach(); +} - return model_proto.detach(); +WinML::IEngineFactory* +LearningModel::GetEngineFactory() { + return engine_factory_.get(); } } // namespace winrt::Windows::AI::MachineLearning::implementation diff --git a/winml/lib/Api/LearningModel.h b/winml/lib/Api/LearningModel.h index 261dc4b8655fa..e00eb6339824a 100644 --- a/winml/lib/Api/LearningModel.h +++ b/winml/lib/Api/LearningModel.h @@ -4,7 +4,12 @@ #pragma once #include "LearningModel.g.h" -#include "core/providers/winml/winml_provider_factory.h" + +namespace Windows::AI::MachineLearning { +struct IEngineFactory; +struct IModel; +struct IModelInfo; +} // namespace Windows::AI::MachineLearning namespace winrt::Windows::AI::MachineLearning::implementation { @@ -93,20 +98,15 @@ struct LearningModel : LearningModelT { /* Non-ABI methods */ bool IsDisposed(); IMLOperatorRegistry* GetOperatorRegistry(); - winmla::IModelProto* DetachModelProto(); - winmla::IModelProto* CopyModelProto(); + WinML::IModel* DetachModel(); + WinML::IModel* CloneModel(); + WinML::IEngineFactory* GetEngineFactory(); private: - void Initialize(); - void LogCreationEvent(bool fromStream = false); - void ModelUseFP16( - winml::ILearningModelFeatureDescriptor descriptor, - bool& use_fp16); + com_ptr engine_factory_; + com_ptr model_; + com_ptr model_info_; - private: - com_ptr adapter_; - com_ptr model_proto_; - com_ptr model_info_; ILearningModelOperatorProvider operator_provider_; }; diff --git a/winml/lib/Api/LearningModelBinding.cpp 
b/winml/lib/Api/LearningModelBinding.cpp index 8076f2fb008a1..65a592a3fbef3 100644 --- a/winml/lib/Api/LearningModelBinding.cpp +++ b/winml/lib/Api/LearningModelBinding.cpp @@ -17,7 +17,6 @@ namespace winrt::Windows::AI::MachineLearning::implementation { LearningModelBinding::LearningModelBinding( Windows::AI::MachineLearning::LearningModelSession const& session) try : m_session(session) { session.as()->CheckClosed(); - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); } WINML_CATCH_ALL @@ -39,10 +38,6 @@ static Windows::AI::MachineLearning::ILearningModelFeatureDescriptor FindValidBi return nullptr; } -LearningModelBinding::~LearningModelBinding() { - Clear(); -} - using NullableBindingPort = std::optional>; static NullableBindingPort FindValidBinding( @@ -63,7 +58,7 @@ void LearningModelBinding::CacheProvider( m_providers[name] = providerInfo; } -std::tuple LearningModelBinding::CreateBinding( +std::tuple, BindingType> LearningModelBinding::CreateBinding( const std::string& name, const Windows::Foundation::IInspectable& inspectable, Windows::Foundation::Collections::IPropertySet const& properties) { @@ -102,10 +97,9 @@ std::tuple LearningModelBind }; // Get the bound tensor - Ort::Value value(nullptr); - Ort::Allocator ort_allocator(adapter_.get(), nullptr); + winrt::com_ptr value; - // Get the native ORT interface for the given bind value + // Get the native interface for the given bind value auto spLotusValueProvider = featureValue.as(); auto spSession = m_session.as(); @@ -126,7 +120,7 @@ std::tuple LearningModelBind if (!isPlaceHolder || shouldAlwaysTensorize) { // If not a placeholder, attempt to get the underlying resource WINML_THROW_IF_FAILED_MSG( - spLotusValueProvider->GetOrtValue(context, value.put(), ort_allocator.put()), + spLotusValueProvider->GetValue(context, value.put()), "The model variable %s failed tensorization.", name.c_str()); } else { @@ -135,13 +129,15 @@ std::tuple LearningModelBind isPlaceHolder && bindingType == BindingType::kInput, "The model variable %s is an input, but has no associated resources to bind.", name.c_str()); + + WINML_THROW_IF_FAILED(spSession->GetEngine()->CreateNullValue(value.put())); } // Hold onto the input output providers so that our memory doesnt get destroyed! auto providerInfo = ProviderInfo{inspectable, spLotusValueProvider, context}; CacheProvider(name, providerInfo); - - return std::make_tuple(name, value.release(), bindingType, ort_allocator.release()); + + return std::make_tuple(name, value, bindingType); } void LearningModelBinding::Bind( @@ -157,26 +153,17 @@ void LearningModelBinding::Bind( Windows::Foundation::Collections::IPropertySet const& properties) try { _winmlt::TelemetryEvent binding_event(_winmlt::EventCategory::kBinding); - BindingType bindingType; - std::string bindingName; - OrtValue* binding_value = nullptr; - OrtAllocator* ort_allocator = nullptr; + BindingType binding_type; + std::string binding_name; + winrt::com_ptr binding_value = nullptr; auto featureName = WinML::Strings::UTF8FromHString(name); - std::tie(bindingName, binding_value, bindingType, ort_allocator) = CreateBinding(featureName, value, properties); - Ort::Value ortValue = binding_value ? 
Ort::Value(binding_value) : Ort::Value(nullptr); - Ort::Allocator ortAllocator(adapter_.get(), ort_allocator); - switch (bindingType) { + std::tie(binding_name, binding_value, binding_type) = CreateBinding(featureName, value, properties); + switch (binding_type) { case BindingType::kInput: - WINML_THROW_IF_FAILED(BindInput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindInput(binding_name, binding_value)); break; case BindingType::kOutput: - WINML_THROW_IF_FAILED(BindOutput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindOutput(binding_name, binding_value)); break; default: FAIL_FAST(); @@ -191,8 +178,6 @@ void LearningModelBinding::Clear() try { outputs_.clear(); output_names_.clear(); m_providers.clear(); - input_allocators_.clear(); - output_allocators_.clear(); } WINML_CATCH_ALL @@ -208,14 +193,14 @@ Windows::Foundation::Collections::IIterator } Windows::Foundation::IInspectable LearningModelBinding::Lookup(hstring const& key) { - auto utf8Name = WinML::Strings::UTF8FromHString(key); + auto utf8_name = WinML::Strings::UTF8FromHString(key); - auto foundIt = m_providers.find(utf8Name); + auto foundIt = m_providers.find(utf8_name); WINML_THROW_HR_IF_FALSE_MSG( E_BOUNDS, foundIt != std::end(m_providers), "The binding collection does not contain a variable with name %s.", - utf8Name.c_str()); + utf8_name.c_str()); auto providerInfo = foundIt->second; return providerInfo.CallerSpecifiedFeatureValue; @@ -226,8 +211,8 @@ uint32_t LearningModelBinding::Size() { } bool LearningModelBinding::HasKey(hstring const& key) { - auto utf8Name = WinML::Strings::UTF8FromHString(key); - return m_providers.find(utf8Name) != m_providers.end(); + auto utf8_name = WinML::Strings::UTF8FromHString(key); + return m_providers.find(utf8_name) != m_providers.end(); } void LearningModelBinding::Split( @@ -239,169 +224,110 @@ void LearningModelBinding::Split( second = nullptr; } -ONNXTensorElementDataType STDMETHODCALLTYPE GetONNXTensorElementDataType(winml::TensorKind kind) { - if (kind == TensorKind::Float) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (kind == TensorKind::Double) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } else if (kind == TensorKind::String) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - } else if (kind == TensorKind::UInt8) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; - } else if (kind == TensorKind::Int8) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; - } else if (kind == TensorKind::UInt16) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; - } else if (kind == TensorKind::Int16) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; - } else if (kind == TensorKind::UInt32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; - } else if (kind == TensorKind::Int32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; - } else if (kind == TensorKind::UInt64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; - } else if (kind == TensorKind::Int64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } else if (kind == TensorKind::Boolean) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; - } else if (kind == TensorKind::Float16) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - } - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; -} - -bool LearningModelBinding::IsOfMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind) { - if (ort_value.GetTypeInfo().GetONNXType() != ONNX_TYPE_MAP) - return false; - - ONNXTensorElementDataType onnx_key_type; - ONNXTensorElementDataType 
onnx_value_type; - - WINML_THROW_IF_FAILED(adapter_->GetMapType(ort_value, &onnx_key_type, &onnx_value_type)); - - if (onnx_key_type != GetONNXTensorElementDataType(key_kind)) - return false; - - if (onnx_value_type != GetONNXTensorElementDataType(value_kind)) - return false; - - return true; -}; - -bool LearningModelBinding::IsOfVectorMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind) { - if (ort_value.GetTypeInfo().GetONNXType() != ONNX_TYPE_SEQUENCE) - return false; - - ONNXTensorElementDataType onnx_key_type; - ONNXTensorElementDataType onnx_value_type; - - WINML_THROW_IF_FAILED(adapter_->GetVectorMapType(ort_value, &onnx_key_type, &onnx_value_type)); - - if (onnx_key_type != GetONNXTensorElementDataType(key_kind)) - return false; - - if (onnx_value_type != GetONNXTensorElementDataType(value_kind)) - return false; - - return true; -}; - -bool LearningModelBinding::IsOfTensorType(const Ort::Value& ort_value, TensorKind kind) { - return ort_value.GetTensorTypeAndShapeInfo().GetElementType() == GetONNXTensorElementDataType(kind); -}; - ILearningModelFeatureValue LearningModelBinding::CreateUnboundOuputFeatureValue( - const Ort::Value& ort_value, + const winrt::com_ptr value, ILearningModelFeatureDescriptor& descriptor) { - if (ort_value.IsTensor()) { - if (IsOfTensorType(ort_value, TensorKind::Float)) { + bool out; + if (SUCCEEDED(value->IsTensor(&out)) && out) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Float, &out)) && out) { if (descriptor.Kind() == LearningModelFeatureKind::Image) { using namespace Windows::Graphics::Imaging; // TODO: this format for unbound output needs more discussion BitmapPixelFormat format = descriptor.as()->BitmapPixelFormat(); - uint32_t width = static_cast(ort_value.GetTensorTypeAndShapeInfo().GetShape()[3]); - uint32_t height = static_cast(ort_value.GetTensorTypeAndShapeInfo().GetShape()[2]); - uint32_t batchSize = static_cast(ort_value.GetTensorTypeAndShapeInfo().GetShape()[0]); + std::vector shape; + value->GetTensorShape(shape); + uint32_t width = static_cast(shape[3]); + uint32_t height = static_cast(shape[2]); + uint32_t batchSize = static_cast(shape[0]); return implementation::ImageFeatureValue::Create(batchSize, format, width, height); } else { return implementation::TensorFloat::Create(); } } - if (IsOfTensorType(ort_value, TensorKind::Double)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Double, &out)) && out) { return implementation::TensorDouble::Create(); } - if (IsOfTensorType(ort_value, TensorKind::String)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::String, &out)) && out) { return implementation::TensorString::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt8)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::UInt8, &out)) && out) { return implementation::TensorUInt8Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int8)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int8, &out)) && out) { return implementation::TensorInt8Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt16)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::UInt16, &out)) && out) { return implementation::TensorUInt16Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int16)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int16, &out)) && out) { return implementation::TensorInt16Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt32)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::UInt32, &out)) && out) { return 
implementation::TensorUInt32Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int32)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int32, &out)) && out) { return implementation::TensorInt32Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::UInt64)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::UInt64, &out)) && out) { return implementation::TensorUInt64Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Int64)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Int64, &out)) && out) { return implementation::TensorInt64Bit::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Boolean)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Boolean, &out)) && out) { return implementation::TensorBoolean::Create(); } - if (IsOfTensorType(ort_value, TensorKind::Float16)) { + if (SUCCEEDED(value->IsOfTensorType(TensorKind::Float16, &out)) && out) { return implementation::TensorFloat16Bit::Create(); } } + // Maps - else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::String)) { + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::String, &out)) && out) { return implementation::MapStringToString::Create(); - } else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::Int64)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::Int64, &out)) && out) { return implementation::MapStringToInt64Bit::Create(); - } else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::Float)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::Float, &out)) && out) { return implementation::MapStringToFloat::Create(); - } else if (IsOfMapType(ort_value, TensorKind::String, TensorKind::Double)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::String, TensorKind::Double, &out)) && out) { return implementation::MapStringToDouble::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, TensorKind::String)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::String, &out)) && out) { return implementation::MapInt64BitToString::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, TensorKind::Int64)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::Int64, &out)) && out) { return implementation::MapInt64BitToInt64Bit::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, TensorKind::Float)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::Float, &out)) && out) { return implementation::MapInt64BitToFloat::Create(); - } else if (IsOfMapType(ort_value, TensorKind::Int64, TensorKind::Double)) { + } + if (SUCCEEDED(value->IsOfMapType(TensorKind::Int64, TensorKind::Double, &out)) && out) { return implementation::MapInt64BitToDouble::Create(); } // Sequences - else if (IsOfVectorMapType(ort_value, TensorKind::String, TensorKind::Float)) { + if (SUCCEEDED(value->IsOfVectorMapType(TensorKind::String, TensorKind::Float, &out)) && out) { return implementation::SequenceMapStringFloat::Create(); - } else if (IsOfVectorMapType(ort_value, TensorKind::Int64, TensorKind::Float)) { + } + if (SUCCEEDED(value->IsOfVectorMapType(TensorKind::Int64, TensorKind::Float, &out)) && out) { return implementation::SequenceMapInt64BitFloat::Create(); } - auto utf8Name = WinML::Strings::UTF8FromHString(descriptor.Name()); + auto utf8_name = WinML::Strings::UTF8FromHString(descriptor.Name()); WINML_THROW_HR_IF_TRUE_MSG( E_UNEXPECTED, true, "The engine produced an unexpected evaluation output for unbound output variable %s.", - 
utf8Name.c_str()); + utf8_name.c_str()); return nullptr; } Windows::Foundation::IInspectable LearningModelBinding::CreateUnboundOutput( const std::string& name, - Ort::Value& ort_value) { + winrt::com_ptr value) { // Find valid binding port auto bindingPort = FindValidBinding( m_session.Model(), @@ -432,12 +358,12 @@ Windows::Foundation::IInspectable LearningModelBinding::CreateUnboundOutput( }; // Create empty feature value - auto featureValue = CreateUnboundOuputFeatureValue(ort_value, descriptor); + auto featureValue = CreateUnboundOuputFeatureValue(value, descriptor); // Update feature value auto spLotusValueProvider = featureValue.as(); WINML_THROW_IF_FAILED_MSG( - spLotusValueProvider->UpdateSourceResourceData(context, ort_value), + spLotusValueProvider->UpdateSourceResourceData(context, value.get()), "Failed to update bound object for model variable output %s", name.c_str()); @@ -454,33 +380,30 @@ Windows::Foundation::IInspectable LearningModelBinding::CreateUnboundOutput( std::unordered_map LearningModelBinding::UpdateProviders() { std::unordered_map outputs; - auto& outputNames = GetOutputNames(); - auto& outputMLValues = GetOutputs(); + auto& output_names = GetOutputNames(); + auto& output_values = GetOutputs(); WINML_THROW_HR_IF_FALSE_MSG( E_UNEXPECTED, - outputNames.size() == outputMLValues.size(), + output_names.size() == output_values.size(), "Evaluation produced unexpected output variables."); - for (unsigned i = 0; i < outputNames.size(); i++) { - auto utf8Name = outputNames[i]; - OrtValue* mlValue = outputMLValues[i]; + for (unsigned i = 0; i < output_names.size(); i++) { + auto utf8_name = output_names[i]; + auto value = output_values[i]; - if (m_providers.find(utf8Name) != std::end(m_providers)) { - auto& providerInfo = m_providers[utf8Name]; + if (m_providers.find(utf8_name) != std::end(m_providers)) { + auto& providerInfo = m_providers[utf8_name]; auto provider = providerInfo.Provider; auto context = providerInfo.Context; WINML_THROW_IF_FAILED_MSG( - provider->UpdateSourceResourceData(context, mlValue), + provider->UpdateSourceResourceData(context, value.get()), "Failed to update bound object for model variable output %s", - utf8Name.c_str()); + utf8_name.c_str()); - outputs[utf8Name] = providerInfo.CallerSpecifiedFeatureValue; + outputs[utf8_name] = providerInfo.CallerSpecifiedFeatureValue; } else { // unbound outputs - Ort::Value ort_value(mlValue); - outputs[utf8Name] = CreateUnboundOutput(utf8Name, ort_value); - // this was a weak ref, don't let it deref() - ort_value.release(); + outputs[utf8_name] = CreateUnboundOutput(utf8_name, value); } } @@ -501,31 +424,23 @@ STDMETHODIMP LearningModelBinding::Bind( IUnknown* value) { try { _winmlt::TelemetryEvent binding_event(_winmlt::EventCategory::kBinding); - BindingType bindingType; - std::string bindingName; - OrtValue* binding_value_ptr = nullptr; - OrtAllocator* ort_allocator = nullptr; + BindingType binding_type; + std::string binding_name; + winrt::com_ptr binding_value; + winrt::Windows::Foundation::IInspectable to; RETURN_IF_FAILED(value->QueryInterface( winrt::guid_of(), reinterpret_cast(winrt::put_abi(to)))); auto featureName = WinML::Strings::UTF8FromUnicode(name, cchName); - std::tie(bindingName, binding_value_ptr, bindingType, ort_allocator) = CreateBinding(featureName, to, nullptr); - Ort::Value ortValue = binding_value_ptr ? 
Ort::Value(binding_value_ptr) : Ort::Value(nullptr); - Ort::Allocator ortAllocator(adapter_.get(), ort_allocator); - switch (bindingType) { + std::tie(binding_name, binding_value, binding_type) = CreateBinding(featureName, to, nullptr); + switch (binding_type) { case BindingType::kInput: - WINML_THROW_IF_FAILED(BindInput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindInput(binding_name, binding_value)); break; case BindingType::kOutput: - WINML_THROW_IF_FAILED(BindOutput( - bindingName, - std::move(ortValue), - std::move(ortAllocator))); + WINML_THROW_IF_FAILED(BindOutput(binding_name, binding_value)); break; default: FAIL_FAST(); @@ -544,43 +459,37 @@ static std::pair Contains(const std::vector& names, c } // This method releases control of memory of ml_value from caller of BindInput -HRESULT LearningModelBinding::BindInput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator) { - auto rc = Contains(input_names_, name); +HRESULT LearningModelBinding::BindInput(const std::string& name, winrt::com_ptr value) { + bool exists; + size_t index; + std::tie(exists, index) = Contains(input_names_, name); - auto add_or_replace = [this, &name](const bool exists, size_t index, Ort::Value&& value, Ort::Allocator&& ort_allocator) { - if (exists) { - inputs_[index] = std::move(value); - input_allocators_[index] = std::move(ort_allocator); - } else { - input_names_.push_back(name); - inputs_.push_back(std::move(value)); - input_allocators_.push_back(std::move(ort_allocator)); - } - }; - if (ml_value.IsTensor()) { - OrtValue* new_mlvalue; - WINML_THROW_IF_FAILED(m_session.as() - ->GetIInferenceSession() - ->CopyOneInputAcrossDevices(name.c_str(), ml_value, &new_mlvalue)); - add_or_replace(rc.first, rc.second, Ort::Value(new_mlvalue), std::move(ort_allocator)); + auto engine = m_session.as()->GetEngine(); + winrt::com_ptr device_value; + WINML_THROW_IF_FAILED(engine->CreateOneInputAcrossDevices(name.c_str(), value.get(), device_value.put())); // an input will always be copied on device mismatch + + if (exists) { + inputs_[index] = device_value; } else { - add_or_replace(rc.first, rc.second, Ort::Value(ml_value.release()), std::move(ort_allocator)); + input_names_.push_back(name); + inputs_.push_back(device_value); } + return S_OK; } -// This method releases control of memory of ml_value from caller of BindInput -HRESULT LearningModelBinding::BindOutput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator) { - auto rc = Contains(output_names_, name); - if (rc.first) { - outputs_[rc.second] = std::move(ml_value); - output_allocators_[rc.second] = std::move(ort_allocator); +HRESULT LearningModelBinding::BindOutput(const std::string& name, winrt::com_ptr value) { + bool exists; + size_t index; + std::tie(exists, index) = Contains(output_names_, name); + + if (exists) { + outputs_[index] = value; return S_OK; } output_names_.push_back(name); - outputs_.push_back(std::move(ml_value)); - output_allocators_.push_back(std::move(ort_allocator)); + outputs_.push_back(value); return S_OK; } @@ -588,13 +497,17 @@ const std::vector& LearningModelBinding::GetOutputNames() const { return output_names_; } -std::vector& LearningModelBinding::GetOutputs() { return outputs_; } - const std::vector& LearningModelBinding::GetInputNames() const { return input_names_; } -const std::vector& LearningModelBinding::GetInputs() const { return inputs_; } +std::vector>& LearningModelBinding::GetOutputs() { + return outputs_; +} + +const 
std::vector>& LearningModelBinding::GetInputs() const { + return inputs_; +} void LearningModelBinding::BindUnboundOutputs() { auto& bound_output_names = GetOutputNames(); @@ -634,7 +547,11 @@ void LearningModelBinding::BindUnboundOutputs() { // Add all unbound outputs to binding collection for (const auto& unbound_output : unbound_output_names) { - WINML_THROW_IF_FAILED(BindOutput(unbound_output, Ort::Value(nullptr), Ort::Allocator())); + auto engine = m_session.as()->GetEngine(); + + winrt::com_ptr value; + WINML_THROW_IF_FAILED(engine->CreateNullValue(value.put())); + WINML_THROW_IF_FAILED(BindOutput(unbound_output, value)); } } diff --git a/winml/lib/Api/LearningModelBinding.h b/winml/lib/Api/LearningModelBinding.h index 0d2efc2339c6b..4dd2734bc0710 100644 --- a/winml/lib/Api/LearningModelBinding.h +++ b/winml/lib/Api/LearningModelBinding.h @@ -22,11 +22,12 @@ struct LearningModelBinding : LearningModelBindingT; LearningModelBinding() = delete; - ~LearningModelBinding(); LearningModelBinding(Windows::AI::MachineLearning::LearningModelSession const& session); void Bind(hstring const& name, Windows::Foundation::IInspectable const& value); void Bind(hstring const& name, Windows::Foundation::IInspectable const& value, Windows::Foundation::Collections::IPropertySet const& properties); + STDMETHOD(Bind)(const wchar_t* name, UINT32 cchName, IUnknown* value); + void Clear(); Windows::Foundation::Collections::IIterator First(); Windows::Foundation::IInspectable Lookup(hstring const& key); @@ -36,7 +37,7 @@ struct LearningModelBinding : LearningModelBindingT& first, Windows::Foundation::Collections::IMapView& second); - std::tuple CreateBinding( + std::tuple, WinML::BindingType> CreateBinding( const std::string& name, const Windows::Foundation::IInspectable& value, Windows::Foundation::Collections::IPropertySet const& properties); @@ -45,42 +46,32 @@ struct LearningModelBinding : LearningModelBindingT& LearningModelBinding::GetOutputNames() const; - std::vector& LearningModelBinding::GetOutputs(); - const std::vector& LearningModelBinding::GetInputNames() const; - const std::vector& LearningModelBinding::GetInputs() const; - HRESULT BindOutput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator); + const std::vector& GetInputNames() const; + const std::vector& GetOutputNames() const; + + const std::vector>& GetInputs() const; + std::vector>& GetOutputs(); + + HRESULT BindOutput(const std::string& name, winrt::com_ptr value); void BindUnboundOutputs(); private: void CacheProvider(std::string name, ProviderInfo& spProvider); - Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, Ort::Value& ort_value); + Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, winrt::com_ptr value); ILearningModelFeatureValue CreateUnboundOuputFeatureValue( - const Ort::Value& ort_value, + const winrt::com_ptr value, ILearningModelFeatureDescriptor& descriptor); - bool IsOfTensorType(const Ort::Value& ort_value, TensorKind kind); - bool IsOfMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind); - bool IsOfVectorMapType(const Ort::Value& ort_value, TensorKind key_kind, TensorKind value_kind); - HRESULT BindInput(const std::string& name, Ort::Value&& ml_value, Ort::Allocator&& ort_allocator); + HRESULT BindInput(const std::string& name, winrt::com_ptr value); private: const Windows::AI::MachineLearning::LearningModelSession m_session; std::unordered_map m_providers; - com_ptr adapter_; std::vector input_names_; - 
std::vector inputs_; - std::vector input_allocators_; + std::vector> inputs_; std::vector output_names_; - std::vector outputs_; - std::vector output_allocators_; + std::vector> outputs_; }; } // namespace winrt::Windows::AI::MachineLearning::implementation diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index 8f77242b5b4e3..013dca5b863ca 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -28,41 +28,42 @@ static const GUID WINML_PIX_EVAL_CAPTURABLE_WORK_GUID = __uuidof(guid_details::W namespace winrt::Windows::AI::MachineLearning::implementation { LearningModelSession::LearningModelSession( - winml::LearningModel const& model) try : LearningModelSession(model, - make(LearningModelDeviceKind::Default)) {} + winml::LearningModel const& model) try : LearningModelSession(model, + make(LearningModelDeviceKind::Default)) {} WINML_CATCH_ALL LearningModelSession::LearningModelSession( - winml::LearningModel const& model, - winml::LearningModelDevice const& deviceToRunOn) try : LearningModelSession(model, - deviceToRunOn, - nullptr) {} + winml::LearningModel const& model, + winml::LearningModelDevice const& deviceToRunOn) try : LearningModelSession(model, + deviceToRunOn, + nullptr) {} WINML_CATCH_ALL LearningModelSession::LearningModelSession( - winml::LearningModel const& model, - winml::LearningModelDevice const& deviceToRunOn, - winml::LearningModelSessionOptions const& learningModelSessionOptions) try : model_(model), - device_(deviceToRunOn), - session_options_(learningModelSessionOptions) { + winml::LearningModel const& model, + winml::LearningModelDevice const& deviceToRunOn, + winml::LearningModelSessionOptions const& learningModelSessionOptions) try : model_(model), + device_(deviceToRunOn), + session_options_(learningModelSessionOptions), + operator_registry_(nullptr, nullptr) { Initialize(); } WINML_CATCH_ALL -winmla::IModelProto* +WinML::IModel* LearningModelSession::GetOptimizedModel() { // Get the model proto auto should_close_model = - session_options_ != nullptr && - session_options_.CloseModelOnSessionCreation(); + session_options_ != nullptr && + session_options_.CloseModelOnSessionCreation(); return GetOptimizedModel(should_close_model); } -winmla::IModelProto* +WinML::IModel* LearningModelSession::GetOptimizedModel(bool should_close_model) { - com_ptr model_proto; + com_ptr model; { // Lock the model detach/copy since multiple threads can access concurrently @@ -70,77 +71,70 @@ LearningModelSession::GetOptimizedModel(bool should_close_model) { // Throw if the model has been disposed and is not capable of creating // new sessions. - auto model = model_.as(); - WINML_THROW_HR_IF_TRUE_MSG(E_INVALIDARG, model->IsDisposed(), + auto model_impl = model_.as(); + WINML_THROW_HR_IF_TRUE_MSG(E_INVALIDARG, model_impl->IsDisposed(), "The model has been disposed."); - model_proto.attach(should_close_model - ? model->DetachModelProto() - : model->CopyModelProto()); + model.attach(should_close_model + ? 
model_impl->DetachModel() + : model_impl->CloneModel()); } // Ensure that the model is runnable on the device - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - WINML_THROW_IF_FAILED(adapter->EnsureModelDeviceCompatibility(model_, model_proto.get(), device_.as()->GetD3DDeviceCache()->IsFloat16Supported())); - - return model_proto.detach(); + auto isFloat16Supported = device_.as()->GetD3DDeviceCache()->IsFloat16Supported(); + if (!isFloat16Supported) { + WINML_THROW_IF_FAILED(model->ModelEnsureNoFloat16()); + } + return model.detach(); } void LearningModelSession::Initialize() { // Begin recording session creation telemetry _winmlt::TelemetryEvent session_creation_event( - _winmlt::EventCategory::kSessionCreation); + _winmlt::EventCategory::kSessionCreation); // Get the optimized model from the learning model - com_ptr model_proto; - model_proto.attach(GetOptimizedModel()); + com_ptr model; + model.attach(GetOptimizedModel()); // Create the session builder auto device_impl = device_.as(); + auto model_impl = model_.as(); - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); + engine_factory_.copy_from(model_impl->GetEngineFactory()); - com_ptr session_builder; - WINML_THROW_IF_FAILED(adapter->CreateOrtSessionBuilder( - device_impl->GetD3DDevice(), - device_impl->GetDeviceQueue(), - session_builder.put())); + com_ptr engine_builder; + engine_factory_->CreateEngineBuilder(engine_builder.put()); - Ort::SessionOptions options(nullptr); - WINML_THROW_IF_FAILED(session_builder->CreateSessionOptions(options.put())); + if (device_impl->IsCpuDevice() == false) { + engine_builder->SetD3D12Resources(device_impl->GetD3DDevice(), device_impl->GetDeviceQueue()); + } // Make onnxruntime apply the batch size override, if any - if (session_options_ && session_options_.BatchSizeOverride() != 0) - { - Ort::ThrowOnError(Ort::GetApi().AddFreeDimensionOverride( - options, - onnx::DATA_BATCH, - session_options_.BatchSizeOverride())); + if (session_options_ && session_options_.BatchSizeOverride() != 0) { + engine_builder->SetBatchSizeOverride(session_options_.BatchSizeOverride()); } - com_ptr session; - WINML_THROW_IF_FAILED(session_builder->CreateSession( - options, session.put(), &cached_execution_provider_)); + com_ptr engine; + WINML_THROW_IF_FAILED(engine_builder->CreateEngine(engine.put())); // Register the custom operator registry - auto model = model_.as(); - operatorRegistry_.reset(model->GetOperatorRegistry()); - WINML_THROW_IF_FAILED(session->RegisterCustomRegistry(operatorRegistry_.get())); + operator_registry_ = MLOperatorRegistry(model_impl->GetOperatorRegistry(), [](auto registry) { registry->Release(); }); + WINML_THROW_IF_FAILED(engine->RegisterCustomRegistry(operator_registry_.get())); - // Register only the transformers not already in ORT - session->RegisterGraphTransformers(); + // Register transformers. This should probably not be exposed on IEngine, but handled as an internal call, since this configuration step is ORT-specific.
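// As an aside, operator_registry_ above relies on std::unique_ptr with a custom
// deleter so the COM registry pointer is Release()'d rather than deleted. A
// minimal, self-contained sketch of that pattern follows; IFakeRegistry and
// WrapRegistry are illustrative stand-ins, not part of this patch:

#include <memory>

struct IFakeRegistry {
  void Release() {}  // COM-style lifetime: Release(), not operator delete
};

using RegistryHandle = std::unique_ptr<IFakeRegistry, void (*)(IFakeRegistry*)>;

RegistryHandle WrapRegistry(IFakeRegistry* raw) {
  // the deleter drops the COM reference when the handle goes out of scope
  return RegistryHandle(raw, [](IFakeRegistry* r) { r->Release(); });
}

// (RegisterGraphTransformers below is the ORT-specific configuration step
// mentioned in the comment above.)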
+ engine->RegisterGraphTransformers(); // Load the model into the session - WINML_THROW_IF_FAILED(session->LoadModel(model_proto.get())); + WINML_THROW_IF_FAILED(engine->LoadModel(model.get())); + // the engine owns the model now; release our local reference - model_proto = nullptr; + model = nullptr; // Initialize the session - WINML_THROW_IF_FAILED(session_builder->Initialize(session.get(), cached_execution_provider_)); + WINML_THROW_IF_FAILED(engine->Initialize()); // Cache the constructed session - inference_session_ = session; + engine_ = engine; } wfc::IPropertySet @@ -165,8 +159,8 @@ LearningModelSession::Device() try { WINML_CATCH_ALL auto CreateBinding( - LearningModelSession& session, - wfc::IMap const features) { + LearningModelSession& session, + wfc::IMap const features) { auto binding = winrt::make(session); for (auto feature : features.GetView()) { @@ -177,8 +171,8 @@ auto CreateBinding( winml::LearningModelEvaluationResult LearningModelSession::EvaluateFeatures( - wfc::IMap const features, - hstring const correlation_id) try { + wfc::IMap const features, + hstring const correlation_id) try { auto binding = CreateBinding(*this, features); return Evaluate(binding, correlation_id); } @@ -186,65 +180,63 @@ WINML_CATCH_ALL wf::IAsyncOperation LearningModelSession::EvaluateFeaturesAsync( - wfc::IMap const features, - hstring const correlation_id) { + wfc::IMap const features, + hstring const correlation_id) { auto binding = CreateBinding(*this, features); return EvaluateAsync(binding, correlation_id); } -// copied from onnxruntime_cxx_inline.h -inline OrtStatus* OrtRun( - OrtSession * session, - const Ort::RunOptions& run_options, - const char* const* input_names, - const Ort::Value* input_values, - size_t input_count, - const char* const* output_names, - Ort::Value* output_values, - size_t output_count) { - static_assert(sizeof(Ort::Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); - auto ort_input_values = reinterpret_cast(const_cast(input_values)); - auto ort_output_values = reinterpret_cast(output_values); - return Ort::GetApi().Run(session, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values); -} - -uint64_t -LearningModelSession::Run( - winrt::com_ptr binding_impl) { +uint64_t LearningModelSession::Run(winrt::com_ptr binding_impl) { CheckClosed(); + auto device = device_.as(); CWinMLAutoLock lock(!device->IsCpuDevice() ? &evaluate_lock_ : nullptr); - // TODO : set the run_options - Ort::RunOptions run_options; + binding_impl->BindUnboundOutputs(); - std::vector inputNames_c; - for (int i=0; i < binding_impl->GetInputNames().size(); i++) - { - inputNames_c.push_back(binding_impl->GetInputNames()[i].c_str()); - } - std::vector outputNames_c; - for (int i = 0; i < binding_impl->GetOutputNames().size(); i++) { - outputNames_c.push_back(binding_impl->GetOutputNames()[i].c_str()); - } - OrtSession* session = nullptr; - - WINML_THROW_IF_FAILED(inference_session_->GetOrtSession(&session)); - // Invoke run on the ORT session.
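// The replacement for the deleted OrtRun wrapper below marshals the bound
// names and values into raw arrays before calling IEngine::Run. The name
// marshaling uses a std::transform pattern; in isolation it looks like the
// following sketch (standard library only; ToRawNames is an illustrative
// helper, not part of the patch):

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

std::vector<const char*> ToRawNames(const std::vector<std::string>& names) {
  std::vector<const char*> raw;
  raw.reserve(names.size());
  // borrow each string's buffer; 'names' must outlive 'raw'
  std::transform(names.begin(), names.end(), std::back_inserter(raw),
                 [](const std::string& name) { return name.c_str(); });
  return raw;
}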
- Ort::ThrowOnError(OrtRun( - session, - run_options, - inputNames_c.data(), - binding_impl->GetInputs().data(), - binding_impl->GetInputs().size(), - outputNames_c.data(), - binding_impl->GetOutputs().data(), - binding_impl->GetOutputs().size())); + auto& input_names = binding_impl->GetInputNames(); + std::vector input_names_raw; + std::transform( + std::begin(input_names), + std::end(input_names), + std::back_inserter(input_names_raw), + [&](auto& name) { return name.c_str(); }); + + auto& inputs = binding_impl->GetInputs(); + std::vector inputs_raw; + std::transform( + std::begin(inputs), + std::end(inputs), + std::back_inserter(inputs_raw), + [&](auto& input) { return input.get(); }); + + auto& output_names = binding_impl->GetOutputNames(); + std::vector output_names_raw; + std::transform( + std::begin(output_names), + std::end(output_names), + std::back_inserter(output_names_raw), + [&](auto& name) { return name.c_str(); }); + + auto outputs = binding_impl->GetOutputs(); + std::vector outputs_raw; + std::transform( + std::begin(outputs), + std::end(outputs), + std::back_inserter(outputs_raw), + [&](auto& output) { return output.get(); }); + + engine_->Run(input_names_raw.data(), + inputs_raw.data(), + input_names_raw.size(), + output_names_raw.data(), + outputs_raw.data(), + output_names_raw.size()); if (!device->IsCpuDevice()) { // Flush the D3D12 work from the DML execution provider and queue a fence before we release the lock. // This allows us to wait without holding onto the lock in GetResults. - inference_session_->FlushContext(GetExecutionProvider()); + engine_->FlushContext(); return device->GetD3DDeviceCache()->QueueFenceToD3D12(); } @@ -254,9 +246,9 @@ LearningModelSession::Run( winml::LearningModelEvaluationResult LearningModelSession::GetResults( - winrt::com_ptr binding_impl, - hstring const& correlation_id, - uint64_t evaluation_complete_fence) { + winrt::com_ptr binding_impl, + hstring const& correlation_id, + uint64_t evaluation_complete_fence) { // First wait on the fence value for the expected frame. This is passed in so that // the fence value is added to the queue in a thread-safe manner. auto device = device_.as(); @@ -271,10 +263,10 @@ LearningModelSession::GetResults( if (is_gpu_evaluation) { // For DML we aren't using the Sync function because we want to make fencing the // completed frame thread safe without holding the lock while waiting for the gpu. - inference_session_->ReleaseCompletedReferences(GetExecutionProvider()); + engine_->ReleaseCompletedReferences(); } else { // For CPU call the standard Sync function - GetExecutionProvider()->Sync(); + engine_->Sync(); } // This isn't ideal: we are holding the lock while we wait for detensorize on the GPU. @@ -286,7 +278,7 @@ LearningModelSession::GetResults( // to avoid requiring the extra allocation during each evaluation.
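// Run() above returns a monotonically increasing fence value (via
// QueueFenceToD3D12) that GetResults() later waits on before reading outputs.
// A toy, CPU-only model of that handshake, using only the standard library
// (ToyFence is illustrative; the real wait goes through the D3DDeviceCache):

#include <condition_variable>
#include <cstdint>
#include <mutex>

struct ToyFence {
  std::mutex m;
  std::condition_variable cv;
  uint64_t last_queued = 0;
  uint64_t last_signaled = 0;

  // analogue of QueueFenceToD3D12(): hand out the next fence value
  uint64_t Queue() {
    std::lock_guard<std::mutex> lock(m);
    return ++last_queued;
  }

  // the device queue would signal completion of the submitted work
  void Signal(uint64_t value) {
    {
      std::lock_guard<std::mutex> lock(m);
      last_signaled = value;
    }
    cv.notify_all();
  }

  // analogue of the wait in GetResults(): block until the evaluation finished
  void Wait(uint64_t value) {
    std::unique_lock<std::mutex> lock(m);
    cv.wait(lock, [&] { return last_signaled >= value; });
  }
};

// (The first-evaluate TrimUploadHeap call below is the extra-allocation
// optimization referred to in the comment above.)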
if (is_first_evaluate_) { if (is_gpu_evaluation) { - inference_session_->TrimUploadHeap(GetExecutionProvider()); + engine_->TrimUploadHeap(); } is_first_evaluate_ = false; } @@ -309,7 +301,7 @@ LearningModelSession::EvaluateAsync( _winmlt::TelemetryEvent kEvaluateModel_event(_winmlt::EventCategory::kEvaluation); auto device = device_.as(); - // Get the ORT binding collection + // Get the binding collection auto binding_impl = binding.as(); ApplyEvaluationProperties(); @@ -369,7 +361,7 @@ LearningModelSession::Evaluate( capture_interface->BeginCapturableWork(WINML_PIX_EVAL_CAPTURABLE_WORK_GUID); } - // Get the ORT binding collection + // Get the binding collection auto binding_impl = binding.as(); uint64_t evaluation_complete_fence = Run(binding_impl); @@ -383,16 +375,14 @@ LearningModelSession::Evaluate( WINML_CATCH_ALL void LearningModelSession::Close() { - inference_session_ = nullptr; + engine_ = nullptr; } void LearningModelSession::ApplyEvaluationProperties() try { if (evaluation_properties_) { auto is_debug_output_enabled = evaluation_properties_.HasKey(c_enable_debug_output); if (is_debug_output_enabled) { - com_ptr adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - adapter->EnableDebugOutput(); + engine_factory_->EnableDebugOutput(is_debug_output_enabled); } } } @@ -407,24 +397,19 @@ void LearningModelSession::ToggleProfiler() { WINML_PROVIDER_KEYWORD_LOTUS_PROFILING); if (is_provider_enabled) { - inference_session_->StartProfiling(); + engine_->StartProfiling(); } else { - inference_session_->EndProfiling(); + engine_->EndProfiling(); } } -onnxruntime::IExecutionProvider* -LearningModelSession::GetExecutionProvider() { - return cached_execution_provider_; -} - -winmla::IInferenceSession* -LearningModelSession::GetIInferenceSession() { - return inference_session_.get(); +WinML::IEngine* +LearningModelSession::GetEngine() { + return engine_.get(); } void LearningModelSession::CheckClosed() { - if (!inference_session_) { + if (!engine_) { WINML_THROW_HR(RO_E_CLOSED); } } diff --git a/winml/lib/Api/LearningModelSession.h b/winml/lib/Api/LearningModelSession.h index 8c0acf51171cc..bdb1dd2fb0d03 100644 --- a/winml/lib/Api/LearningModelSession.h +++ b/winml/lib/Api/LearningModelSession.h @@ -9,6 +9,7 @@ #include "MLOperatorAuthor.h" #include "WinML_Lock.h" #include "core/providers/winml/winml_provider_factory.h" +#include "iengine.h" namespace winrt::Windows::AI::MachineLearning::implementation { @@ -66,11 +67,9 @@ struct LearningModelSession : LearningModelSessionT { public: /* Non-ABI methods */ - onnxruntime::IExecutionProvider* - GetExecutionProvider(); - winmla::IInferenceSession* - GetIInferenceSession(); + WinML::IEngine* + GetEngine(); void CheckClosed(); @@ -79,10 +78,10 @@ struct LearningModelSession : LearningModelSessionT { void Initialize(); - winmla::IModelProto* + WinML::IModel* GetOptimizedModel(); - winmla::IModelProto* + WinML::IModel* GetOptimizedModel(bool should_close_model); uint64_t @@ -102,16 +101,11 @@ struct LearningModelSession : LearningModelSessionT { ToggleProfiler(); private: - com_ptr inference_session_; - struct IMLOperatorRegistryDeleter { - void operator()(IMLOperatorRegistry* p) { - p->Release(); - } - }; - std::unique_ptr operatorRegistry_; - - // reference to the active execution provider. 
weak - onnxruntime::IExecutionProvider* cached_execution_provider_ = nullptr; + com_ptr engine_factory_; + com_ptr engine_; + + using MLOperatorRegistry = std::unique_ptr; + MLOperatorRegistry operator_registry_; winml::LearningModel model_; winml::LearningModelDevice device_; diff --git a/winml/lib/Api/MapFeatureDescriptor.cpp b/winml/lib/Api/MapFeatureDescriptor.cpp index d30734c3be065..60f63c13f85d1 100644 --- a/winml/lib/Api/MapFeatureDescriptor.cpp +++ b/winml/lib/Api/MapFeatureDescriptor.cpp @@ -18,19 +18,6 @@ MapFeatureDescriptor::MapFeatureDescriptor( value_kind_(value_kind) { } -MapFeatureDescriptor::MapFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& KeyKind, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ValueDescriptor) : - name_(Name), - description_(Description), - is_required_(IsRequired), - key_kind_(KeyKind), - value_kind_(ValueDescriptor) { -} - winml::TensorKind MapFeatureDescriptor::KeyKind() try { return key_kind_; diff --git a/winml/lib/Api/MapFeatureDescriptor.h b/winml/lib/Api/MapFeatureDescriptor.h index 1b752b2eb3b63..3641585dd7d87 100644 --- a/winml/lib/Api/MapFeatureDescriptor.h +++ b/winml/lib/Api/MapFeatureDescriptor.h @@ -17,14 +17,7 @@ struct MapFeatureDescriptor : MapFeatureDescriptorT< bool is_required, winml::TensorKind keyKind, winml::ILearningModelFeatureDescriptor valueKind); - - MapFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::TensorKind const& KeyKind, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ValueDescriptor); - + // IMapDescriptor winml::TensorKind KeyKind(); @@ -62,10 +55,4 @@ struct MapFeatureDescriptor : MapFeatureDescriptorT< winml::TensorKind key_kind_; winml::ILearningModelFeatureDescriptor value_kind_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - -namespace winrt::Windows::AI::MachineLearning::factory_implementation { - struct MapFeatureDescriptor : MapFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/SequenceFeatureDescriptor.cpp b/winml/lib/Api/SequenceFeatureDescriptor.cpp index 725a66bae253b..0cc1248cc88eb 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.cpp +++ b/winml/lib/Api/SequenceFeatureDescriptor.cpp @@ -16,16 +16,6 @@ SequenceFeatureDescriptor::SequenceFeatureDescriptor( description_(WinML::Strings::HStringFromUTF8(description)), is_required_(is_required), element_descriptor_(descriptor) {} -SequenceFeatureDescriptor::SequenceFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ElementDescriptor) : - name_(Name), - description_(Description), - is_required_(IsRequired), - element_descriptor_(ElementDescriptor) { -} winml::ILearningModelFeatureDescriptor diff --git a/winml/lib/Api/SequenceFeatureDescriptor.h b/winml/lib/Api/SequenceFeatureDescriptor.h index 04e5d392ae261..c45a06ccaba38 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.h +++ b/winml/lib/Api/SequenceFeatureDescriptor.h @@ -15,11 +15,6 @@ struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< const char* description, bool is_required, winml::ILearningModelFeatureDescriptor element_descriptor); - SequenceFeatureDescriptor( - 
hstring const& Name, - hstring const& Description, - bool IsRequired, - Windows::AI::MachineLearning::ILearningModelFeatureDescriptor const& ElementDescriptor); winml::ILearningModelFeatureDescriptor ElementDescriptor(); @@ -53,10 +48,4 @@ struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< bool is_required_; winml::ILearningModelFeatureDescriptor element_descriptor_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - -namespace winrt::Windows::AI::MachineLearning::factory_implementation { - struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/TensorFeatureDescriptor.cpp b/winml/lib/Api/TensorFeatureDescriptor.cpp index e4517f7b3870b..3cf7cc6a36fd9 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.cpp +++ b/winml/lib/Api/TensorFeatureDescriptor.cpp @@ -11,9 +11,9 @@ namespace winrt::Windows::AI::MachineLearning::implementation { TensorFeatureDescriptor::TensorFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, bool has_unsupported_image_metadata) : name_(WinML::Strings::HStringFromUTF8(name)), description_(WinML::Strings::HStringFromUTF8(description)), tensor_kind_(tensor_kind), @@ -22,20 +22,6 @@ TensorFeatureDescriptor::TensorFeatureDescriptor( has_unsupported_image_metadata_(has_unsupported_image_metadata) { } -TensorFeatureDescriptor::TensorFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - winml::TensorKind const& TensorKind, - array_view Shape, - bool HasUnsupportedImageMetadata) : name_(Name), - description_(Description), - tensor_kind_(TensorKind), - shape_(Shape.begin(), Shape.end()), - is_required_(IsRequired), - has_unsupported_image_metadata_(HasUnsupportedImageMetadata) { -} - winml::TensorKind TensorFeatureDescriptor::TensorKind() try { return tensor_kind_; @@ -75,11 +61,6 @@ bool TensorFeatureDescriptor::IsRequired() try { } WINML_CATCH_ALL -bool TensorFeatureDescriptor::HasUnsupportedImageMetadata() try { - return has_unsupported_image_metadata_; -} -WINML_CATCH_ALL - bool TensorFeatureDescriptor::IsUnsupportedMetaData() try { return has_unsupported_image_metadata_; } diff --git a/winml/lib/Api/TensorFeatureDescriptor.h b/winml/lib/Api/TensorFeatureDescriptor.h index 975b359c1f13b..5e54978c5847a 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.h +++ b/winml/lib/Api/TensorFeatureDescriptor.h @@ -13,23 +13,17 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< TensorFeatureDescriptor( const char* name, const char* description, - bool is_required, winml::TensorKind tensor_kind, const std::vector& shape, + bool is_required, bool has_unsuppored_image_metadata); - TensorFeatureDescriptor( - hstring const& Name, - hstring const& Description, - bool IsRequired, - TensorKind const& TensorKind, - array_view Shape, - bool HasUnsupportedImageMetadata); - // ITensorDescriptor - winml::TensorKind TensorKind(); - wfc::IVectorView Shape(); - bool HasUnsupportedImageMetadata(); + winml::TensorKind + TensorKind(); + + wfc::IVectorView + Shape(); // IFeatureDescriptor winrt::hstring @@ -65,10 +59,4 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< bool is_required_; bool has_unsupported_image_metadata_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - 
-namespace winrt::Windows::AI::MachineLearning::factory_implementation { - struct TensorFeatureDescriptor : TensorFeatureDescriptorT { - - }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file diff --git a/winml/lib/Api/impl/MapBase.h b/winml/lib/Api/impl/MapBase.h index b3490bbe1a03b..d59366140f69e 100644 --- a/winml/lib/Api/impl/MapBase.h +++ b/winml/lib/Api/impl/MapBase.h @@ -40,53 +40,9 @@ struct MapBase : winrt::implements< std::is_same::value, "Map values must be int64_t, double, float, or winrt::hstring!"); - template - struct ValidLotusType { using Type = T; }; - template <> - struct ValidLotusType { using Type = std::string; }; - - using LotusKey = typename ValidLotusType::Type; - using LotusValue = typename ValidLotusType::Type; - using LotusMap = std::pair, std::vector>; using ABIMap = ::winrt::Windows::Foundation::Collections::IMap; using ABIMapView = ::winrt::Windows::Foundation::Collections::IMapView; - template - static typename ValidLotusType::Type ConvertToValidLotusType(TRawType raw) { - return raw; - } - - template <> - static typename ValidLotusType::Type ConvertToValidLotusType(winrt::hstring raw) { - return WinML::Strings::UTF8FromHString(raw); - } - - template - static std::vector ConvertToABIType(Ort::Value& ort_value) { - // make sure this is an array of these types - auto shape = ort_value.GetTensorTypeAndShapeInfo().GetShape(); - // there needs to be only one dimension - THROW_HR_IF(E_INVALIDARG, shape.size() != 1); - auto lotus_value = ort_value.GetTensorMutableData::Type>(); - // now go through all the entries - std::vector out; - for (auto i = 0; i < shape[0]; i++) { - out.push_back(lotus_value[i]); - } - // retun the vector - return out; - } - - template <> - static std::vector ConvertToABIType(Ort::Value& ort_value) { - auto strings = ort_value.GetStrings(); - std::vector out; - for (auto i = 0; i < strings.size(); ++i) { - out.push_back(WinML::Strings::HStringFromUTF8(strings[i].c_str())); - } - return out; - } - MapBase(ABIMap const& data) : data_(data) {} static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create() { @@ -129,51 +85,16 @@ struct MapBase : winrt::implements< return S_OK; } - void ConvertToLotusMap(const ABIMap& map) { - std::vector keys; - std::vector values; - for (const auto& pair : map) { - auto key = ConvertToValidLotusType(pair.Key()); - auto value = ConvertToValidLotusType(pair.Value()); - keys.push_back(key); - values.push_back(value); - } - lotus_data_ = std::make_unique(std::make_pair(keys, values)); - } - - template - static onnxruntime::MLDataType GetLotusType(winmla::IWinMLAdapter* adapter) { - return adapter->GetMapType(TensorKindFrom::Type, TensorKindFrom::Type); - } + STDMETHOD(GetValue) + (WinML::BindingContext& context, IValue** out) { + auto session = context.session.as(); + auto engine = session->GetEngine(); - template - static Ort::Value CreateOrtMap(TLotusKey* keys, TLotusValue* values, size_t len) { - // now create OrtValue wrappers over the buffers - auto cpu_memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::vector shape = {static_cast(len)}; - auto keys_ort_value = Ort::Value::CreateTensor(cpu_memory, keys, len, shape.data(), shape.size()); - auto values_ort_value = Ort::Value::CreateTensor(cpu_memory, values, len, shape.data(), shape.size()); - // make the map - return Ort::Value::CreateMap(keys_ort_value, values_ort_value); - } - - 
STDMETHOD(GetOrtValue) - (WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator) { - ORT_UNUSED_PARAMETER(ort_allocator); - ORT_UNUSED_PARAMETER(context); - // TODO: Tensorized data should be cached so multiple bindings work more efficiently - - // TODO : we need to handle inputs. for now only handle outputs and don't pre allocate anything - if (context.type == WinML::BindingType::kOutput) { - *ort_value = nullptr; - return S_OK; + if (context.type == WinML::BindingType::kInput) { + RETURN_IF_FAILED(engine->CreateMapValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, out)); + } else { + RETURN_IF_FAILED(engine->CreateNullValue(out)); } - - // handle inputs, create and store a copy of the map - ConvertToLotusMap(data_); - - // and make the map - *ort_value = CreateOrtMap(lotus_data_->first.data(), lotus_data_->second.data(), lotus_data_->first.size()).release(); return S_OK; } @@ -185,51 +106,23 @@ struct MapBase : winrt::implements< } STDMETHOD(UpdateSourceResourceData) - (BindingContext& context, OrtValue* ort_value) { - ORT_UNUSED_PARAMETER(context); + (BindingContext& context, IValue* value) { data_.Clear(); - - Ort::AllocatorWithDefaultOptions allocator; - - // get the keys - OrtValue* ptr = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetValue(ort_value, 0, allocator, &ptr)); - Ort::Value keys{ptr}; - // get the values - ptr = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetValue(ort_value, 1, allocator, &ptr)); - Ort::Value values{ptr}; - - auto keys_vector = ConvertToABIType(keys); - auto values_vector = ConvertToABIType(values); - - auto len = keys.GetCount(); - for (auto i = 0; i < len; ++i) { - data_.Insert(keys_vector[i], values_vector[i]); - } - return S_OK; - - // TODO: code this - //const LotusMap& map = *static_cast(pResource); - //for (const auto& pair : map) { - // auto key = ConvertToABIType(pair.first); - // auto value = ConvertToABIType(pair.second); - // data_.Insert(key, value); - //} - + auto session = context.session.as(); + auto engine = session->GetEngine(); + RETURN_IF_FAILED(engine->FillFromMapValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, value)); return S_OK; } STDMETHOD(AbiRepresentation) ( - winrt::Windows::Foundation::IInspectable& abiRepresentation) { + winrt::Windows::Foundation::IInspectable& abiRepresentation) { data_.as(abiRepresentation); return S_OK; } private: ABIMap data_; - std::unique_ptr lotus_data_; }; } // namespace Windows::AI::MachineLearning diff --git a/winml/lib/Api/impl/SequenceBase.h b/winml/lib/Api/impl/SequenceBase.h index 560ade3faeae6..d84008ec69ed9 100644 --- a/winml/lib/Api/impl/SequenceBase.h +++ b/winml/lib/Api/impl/SequenceBase.h @@ -22,35 +22,28 @@ struct SequenceBase : public winrt::implements< winml::ILearningModelFeatureValue, WinML::ISequenceFeatureValue, WinML::ILotusValueProviderPrivate> { + using ABISequence = wfc::IIterable; using AbiMapStringToFloat = wfc::IMap; using AbiMapInt64BitToFloat = wfc::IMap; - template - struct ValidLotusType { using Type = T; }; - template <> - struct ValidLotusType { - //using Type = std::map; - using TKey = std::string; - using TValue = float; - using Type = std::pair, std::vector>; - using ABIKey = winrt::hstring; - using ABIValue = TValue; + template struct SequenceAbiTypeInfo { + static constexpr winml::TensorKind Key = winml::TensorKind::Undefined; + static constexpr winml::TensorKind Value = winml::TensorKind::Undefined; + }; + template <> 
struct SequenceAbiTypeInfo { + static constexpr winml::TensorKind Key = winml::TensorKind::String; + static constexpr winml::TensorKind Value = winml::TensorKind::Float; }; template <> - struct ValidLotusType { - //using Type = std::map; - using TKey = int64_t; - using TValue = float; - using Type = std::pair, std::vector>; - using ABIKey = TKey; - using ABIValue = TValue; + struct SequenceAbiTypeInfo { + static constexpr winml::TensorKind Key = winml::TensorKind::Int64; + static constexpr winml::TensorKind Value = winml::TensorKind::Float; }; template void GetElementDescriptor(winml::ILearningModelFeatureDescriptor* result) { - *result = TensorFeatureDescriptorFrom::CreateAnonymous( - std::vector{1, 1, 1, 1}); + static_assert(false, "Only sequences of map and map are supported."); } template <> @@ -87,9 +80,6 @@ struct SequenceBase : public winrt::implements< value_descriptor /* value kind */); } - using LotusSequence = std::vector::Type>; - using ABISequence = wfc::IIterable; - SequenceBase(const ABISequence& data) : data_(data) {} static winml::ILearningModelFeatureValue @@ -120,114 +110,22 @@ struct SequenceBase : public winrt::implements< return S_OK; } - template - static - typename ValidLotusType::Type - ConvertToValidLotusType( - TRawType raw) { - return raw; - } - - template <> - static - typename ValidLotusType::Type - ConvertToValidLotusType( - winrt::hstring raw) { - return WinML::Strings::UTF8FromHString(raw); - } - - template <> - static - typename ValidLotusType::Type - ConvertToValidLotusType( - AbiMapStringToFloat raw) { - std::vector::TKey> keys; - std::vector::TValue> values; - for (auto pair : raw) { - auto key = WinML::Strings::UTF8FromHString(pair.Key()); - keys.push_back(key); - values.push_back(pair.Value()); - } - return std::make_pair(keys, values); - } - - template <> - static - typename ValidLotusType::Type - ConvertToValidLotusType( - AbiMapInt64BitToFloat raw) { - std::vector::TKey> keys; - std::vector::TValue> values; - for (const auto& pair : raw) { - keys.push_back(pair.Key()); - values.push_back(pair.Value()); - } - return std::make_pair(keys, values); - } - - void - ConvertToLotusSequence( - const ABISequence& sequence) { - LotusSequence lotus_sequence; - - std::transform( - begin(sequence), - end(sequence), - std::back_inserter(lotus_sequence), - [](const auto& value) { - return ConvertToValidLotusType(value); - }); - - lotus_data_ = std::make_unique(lotus_sequence); - } - - template - static Ort::Value CreateOrtMap(TLotusKey* keys, TLotusValue* values, size_t len) { - // now create OrtValue wrappers over the buffers - auto cpu_memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::vector shape = {static_cast(len)}; - auto keys_ort_value = Ort::Value::CreateTensor(cpu_memory, keys, len, shape.data(), shape.size()); - auto values_ort_value = Ort::Value::CreateTensor(cpu_memory, values, len, shape.data(), shape.size()); - // make the map - return Ort::Value::CreateMap(keys_ort_value, values_ort_value); - } - - STDMETHOD(GetOrtValue)( + STDMETHOD(GetValue)( WinML::BindingContext& context, - OrtValue** ort_value, - OrtAllocator** ort_allocator) { - ORT_UNUSED_PARAMETER(ort_allocator); - // TODO: Tensorized data should be cached so multiple bindings work more efficiently - - // TODO : we need to handle inputs.
for now only handle outputs and don't pre allocate anything - if (context.type == WinML::BindingType::kOutput) { - *ort_value = nullptr; - return S_OK; + IValue** out) { + auto session = context.session.as(); + auto engine = session->GetEngine(); + + if (context.type == WinML::BindingType::kInput) { + // In opset 10, all ops that use sequences are seq. + // In opset 11, we will need to support seq> as well. + RETURN_IF_FAILED(engine->CreateSequenceOfMapsValue( + reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), + SequenceAbiTypeInfo::Key, SequenceAbiTypeInfo::Value, out)); + } else { + RETURN_IF_FAILED(engine->CreateNullValue(out)); } - - // handle inputs, create and store a copy of the sequence - ConvertToLotusSequence(data_); - - // now create OrtValue wrappers over the buffers - std::vector sequence_values; - for (auto it = lotus_data_->begin(); it != lotus_data_->end(); ++it) { - // make a ort value for this map - auto map = *it; - sequence_values.emplace_back(CreateOrtMap(map.first.data(), map.second.data(), map.first.size())); - } - *ort_value = Ort::Value::CreateSequence(sequence_values).release(); return S_OK; - - /* winrt::com_ptr adapter; - RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - auto lotus_type = adapter->GetVectorMapType( - TensorKindFrom::TKey>::Type, - TensorKindFrom::TValue>::Type); - - winrt::com_ptr ml_value_out; - adapter->CreateOrtValue(lotus_data_.get(), lotus_type, ml_value_out.put()); - - *ml_value = ml_value_out.detach();*/ } STDMETHOD(IsPlaceholder) @@ -238,61 +136,16 @@ struct SequenceBase : public winrt::implements< return S_OK; } - template - static std::vector ConvertToABIType(Ort::Value& ort_value) { - // make sure this is an array of these types - auto shape = ort_value.GetTensorTypeAndShapeInfo().GetShape(); - // there needs to be only one dimension - THROW_HR_IF(E_INVALIDARG, shape.size() != 1); - auto lotus_value = ort_value.GetTensorMutableData::Type>(); - // now go through all the entries - std::vector out; - for (auto i = 0; i < shape[0]; i++) { - out.push_back(lotus_value[i]); - } - // return the vector - return out; - } - - template <> - static std::vector ConvertToABIType(Ort::Value& ort_value) { - auto strings = ort_value.GetStrings(); - std::vector out; - for (auto i = 0; i < strings.size(); ++i) { - out.push_back(WinML::Strings::HStringFromUTF8(strings[i].c_str())); - } - return out; - } - STDMETHOD(UpdateSourceResourceData)( BindingContext& context, - OrtValue* ort_value) { - ORT_UNUSED_PARAMETER(context); + IValue* out) { auto writable_vector = data_.as>(); writable_vector.Clear(); - Ort::AllocatorWithDefaultOptions allocator; - size_t len; - Ort::ThrowOnError(Ort::GetApi().GetValueCount(ort_value, &len)); - for (auto i = 0; i < len; ++i) { - OrtValue* out = nullptr; - Ort::ThrowOnError(Ort::GetApi().GetValue(ort_value, i, allocator, &out)); - Ort::Value map{out}; - auto keys = map.GetValue(0, allocator); - auto values = map.GetValue(1, allocator); + auto session = context.session.as(); + auto engine = session->GetEngine(); + RETURN_IF_FAILED(engine->FillSequenceOfMapsValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), SequenceAbiTypeInfo::Key, SequenceAbiTypeInfo::Value, out)); - auto keys_vector = ConvertToABIType::ABIKey>(keys); - auto values_vector = ConvertToABIType::ABIValue>(values); - - std::map::ABIKey, typename ValidLotusType::ABIValue> std_map; - for (auto j = 0; j < keys_vector.size(); ++j) { - std_map[keys_vector[j]] = values_vector[j]; - } - auto abi_map = winrt::single_threaded_map::ABIKey, 
-      auto abi_map = winrt::single_threaded_map<typename ValidLotusType<T>::ABIKey,
-                                                typename ValidLotusType<T>::ABIValue>(
-          std::move(std_map));
-
-      writable_vector.Append(abi_map);
-    }
     return S_OK;
   }

@@ -304,7 +157,6 @@ struct SequenceBase : public winrt::implements<
  private:
   ABISequence data_;
-  std::unique_ptr<LotusSequence> lotus_data_;
 };

-}  // namespace Windows::AI::MachineLearning
+}  // namespace Windows::AI::MachineLearning
\ No newline at end of file
diff --git a/winml/lib/Api/impl/Tensor.h b/winml/lib/Api/impl/Tensor.h
index 69503cb42b63e..2e0e2b16ee34d 100644
--- a/winml/lib/Api/impl/Tensor.h
+++ b/winml/lib/Api/impl/Tensor.h
@@ -5,14 +5,6 @@

 #include "TensorBuffer.h"

-// we further specialize these base types for a couple of extra tensor element types
-namespace Ort {
-template <>
-struct TypeToTensorType<std::string> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; };
-template <>
-struct TypeToTensorType<onnxruntime::MLFloat16> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
-}
-
 //
 // the Tensor class is the actual object for CPU memory buffers.
 // TensorBase contains one of these to represent the raw memory
@@ -27,13 +19,11 @@ class Tensor {
   TensorBufferPtr m_buffer;
   std::vector<int64_t> shape_;
-  winrt::com_ptr<winmla::IWinMLAdapter> adapter_;

  public:
   Tensor() = delete;
   Tensor(
-      winmla::IWinMLAdapter* adapter,
       std::vector<int64_t> const& shape,
       winrt::Windows::Storage::Streams::IBuffer buffer) : shape_(shape),
                                                           m_buffer(TensorBuffer::Create(
                                                               static_cast<uint32_t>(std::accumulate(
                                                                   std::begin(shape),
                                                                   std::end(shape),
                                                                   static_cast<int64_t>(1),
                                                                   std::multiplies<int64_t>())),
                                                               buffer)) {
-    adapter_.copy_from(adapter);
   }

   Tensor(
-      winmla::IWinMLAdapter* adapter,
       std::vector<int64_t> const& shape) : shape_(shape),
                                            m_buffer(TensorBuffer::Create(
                                                static_cast<uint32_t>(std::accumulate(
                                                    std::begin(shape),
                                                    std::end(shape),
                                                    static_cast<int64_t>(1),
                                                    std::multiplies<int64_t>())))) {
-    adapter_.copy_from(adapter);
   }

   Tensor(
-      winmla::IWinMLAdapter* adapter,
       std::vector<int64_t> const&& shape) : shape_(std::move(shape)),
                                             m_buffer(TensorBuffer::Create(
                                                 static_cast<uint32_t>(std::accumulate(
                                                     std::begin(shape),
                                                     std::end(shape),
                                                     static_cast<int64_t>(1),
                                                     std::multiplies<int64_t>())))) {
-    adapter_.copy_from(adapter);
   }

   auto size() const {
     return m_buffer->Size();
   }

-  auto buffer() {
-    return m_buffer->Buffer();
+  auto size_in_bytes() const {
+    return m_buffer->SizeInBytes();
   }

-  Ort::Value GetValue() {
-    // this is cpu memory
-    // TODO: what is the difference between the device allocator and the arena allocator?
-    Ort::MemoryInfo cpu_memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-    // create the OrtValue as a tensor letting ort know that we own the data buffer
-    auto value = Ort::Value::CreateTensor(
-        cpu_memory,
-        buffer().second,
-        m_buffer->SizeInBytes(),
-        shape_.data(),
-        shape_.size());
-//        Ort::TypeToTensorType<T>::type);
-    return value;
+  auto buffer() {
+    return m_buffer->Buffer();
   }

   void set(uint32_t size, const T* pData) {
@@ -111,5 +84,9 @@ class Tensor {
   const std::vector<int64_t>& shape() const {
     return shape_;
   }
+
+  auto get_tensor_buffer() {
+    return m_buffer;
+  }
 };
 }  // namespace Windows::AI::MachineLearning
\ No newline at end of file
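For context on the Tensor<T> changes above: the adapter no longer flows through the constructors, and the new size_in_bytes()/get_tensor_buffer() accessors let the engine wrap the CPU buffer without owning it. A minimal sketch of the intended hand-off, assuming an IEngine* named engine (the interface appears in iengine.h later in this patch):

    // Sketch only; mirrors the CreateTensorValueFromExternalBuffer call sites in TensorBase.h.
    auto cpu_tensor = std::make_shared<Tensor<float>>(std::vector<int64_t>{1, 3, 224, 224});
    winrt::com_ptr<WinML::IValue> value;
    engine->CreateTensorValueFromExternalBuffer(
        cpu_tensor->buffer().second,  // Tensor<T>::buffer() returns a (count, T*) pair
        cpu_tensor->size_in_bytes(),  // byte count, not element count
        cpu_tensor->shape().data(),
        cpu_tensor->shape().size(),
        winml::TensorKind::Float,
        value.put());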
diff --git a/winml/lib/Api/impl/TensorBase.h b/winml/lib/Api/impl/TensorBase.h
index c706ecbd4f914..c79197fe4b9e9 100644
--- a/winml/lib/Api/impl/TensorBase.h
+++ b/winml/lib/Api/impl/TensorBase.h
@@ -15,6 +15,7 @@
 #include "core/session/onnxruntime_c_api.h"

 namespace Windows::AI::MachineLearning {
+
 // TensorBase
 //
 // This is the base class for all data-based Tensor types. It exposes array and IVectorView
@@ -69,87 +70,78 @@ struct TensorBase : TBase {
   /// 3) use provided backing gpu memory
   ///    a) TensorBase(std::vector<int64_t> const& shape, ID3D12Resource* pResource)
   TensorBase() : m_resources(std::make_shared<TensorResources<T>>()) {
-    WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put()));
   }

   TensorBase(winrt::Windows::Foundation::Collections::IIterable<int64_t> const& shape) : shape_(begin(shape), end(shape)),
                                                                                          m_resources(std::make_shared<TensorResources<T>>()) {
-    WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put()));
-    GetCpuResource() = std::make_shared<Tensor<T>>(adapter_.get(), shape_);
+    GetCpuResource() = std::make_shared<Tensor<T>>(shape_);
   }

   TensorBase(std::vector<int64_t> const& shape) : shape_(shape),
                                                   m_resources(std::make_shared<TensorResources<T>>()) {
-    WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put()));
-    GetCpuResource() = std::make_shared<Tensor<T>>(adapter_.get(), shape_);
+    GetCpuResource() = std::make_shared<Tensor<T>>(shape_);
   }

-  TensorBase(std::vector<int64_t> const& shape, ID3D12Resource* pResource, UINT64 resource_width) : shape_(shape),
-                                                                                                    m_resources(std::make_shared<TensorResources<T>>()) {
-    WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put()));
+  TensorBase(std::vector<int64_t> const& shape, ID3D12Resource* resource) : shape_(shape),
+                                                                            m_resources(std::make_shared<TensorResources<T>>()) {
     // This Api is not supported for TensorString
     WINML_THROW_HR_IF_TRUE_MSG(
         E_ILLEGAL_METHOD_CALL,
         (std::is_same<T, std::string>::value),
         "TensorString objects cannot be created from a ID3D12Resource!");

-    GetGpuResource() = std::make_shared<DMLResource>(pResource, resource_width);
+    GetGpuResource().copy_from(resource);
   }

-  Ort::Value CreateGPUMLValue(std::shared_ptr<DMLResource>& resource, BindingContext& context) {
+  HRESULT CreateGPUMLValue(ID3D12Resource* resource, BindingContext& context, IValue** out) {
     THROW_HR_IF_NULL(E_INVALIDARG, resource);
-    THROW_HR_IF_NULL(E_UNEXPECTED, resource->ExecutionProviderAllocatedResource);
-
-    Ort::MemoryInfo dml_memory(nullptr);
-    auto session_impl = context.session.as<winmlp::LearningModelSession>();
-    auto provider = session_impl->GetExecutionProvider();
-    WINML_THROW_IF_FAILED(adapter_->GetProviderMemoryInfo(provider, dml_memory.put()));

-    auto spSession = context.session.as<winmlp::LearningModelSession>();
-    auto spDevice = spSession->Device().as<winmlp::LearningModelDevice>();
+    auto session = context.session.as<winmlp::LearningModelSession>();
+    auto device = session->Device().as<winmlp::LearningModelDevice>();
     WINML_THROW_HR_IF_TRUE_MSG(WINML_ERR_INVALID_BINDING,
-                               spDevice->IsCpuDevice(),
+                               device->IsCpuDevice(),
                                "Cannot create GPU tensor on CPU device");

-    // create the OrtValue as a tensor letting ort know that we own the data buffer
-    auto value = Ort::Value::CreateTensor(
-        dml_memory,
-        resource->ExecutionProviderAllocatedResource,
-        resource->resource_width_,
-        shape_.data(),
-        shape_.size(),
-        Ort::TypeToTensorType<T>::type);
-    return value;
+    auto engine = session->GetEngine();
+    RETURN_IF_FAILED(engine->CreateTensorValueFromExternalD3DResource(resource, shape_.data(), shape_.size(), TensorKind(), out));
+    return S_OK;
   }

-  Ort::Value CPUTensorize(WinML::BindingContext& context) {
+  HRESULT CPUTensorize(WinML::BindingContext& context, IValue** out) {
+    auto session = context.session.as<winmlp::LearningModelSession>();
+    auto engine = session->GetEngine();
+
     if (GetCpuResource() != nullptr) {
-      return GetCpuResource()->GetValue();
+      return CreateTensorValueFromExternalBuffer(engine, out);
     }

     // If there is no matching cpu resource, then fallback to a gpu resource
     if (GetGpuResource() != nullptr) {
-      return CreateGPUMLValue(GetGpuResource(), context);
+      return CreateGPUMLValue(GetGpuResource().get(), context, out);
     }

     WINML_THROW_HR(WINML_ERR_INVALID_BINDING);
   }

-  Ort::Value GPUTensorize(WinML::BindingContext& context) {
+  HRESULT GPUTensorize(WinML::BindingContext& context, IValue** out) {
     if (GetGpuResource() != nullptr) {
-      return CreateGPUMLValue(GetGpuResource(), context);
+      return CreateGPUMLValue(GetGpuResource().get(), context, out);
     }

+    // Get engine
+    auto session = context.session.as<winmlp::LearningModelSession>();
+    auto engine = session->GetEngine();
+
     // If there is no matching gpu resource, then fallback to a cpu resource
     if (GetCpuResource() != nullptr) {
-      return GetCpuResource()->GetValue();
+      return CreateTensorValueFromExternalBuffer(engine, out);
     }

     if (TensorKind() == winrt::Windows::AI::MachineLearning::TensorKind::String) {
       // Lazily allocate the cpu TensorString resource
       // TensorStrings are CPU only, and so a gpu resource cannot be allocated for them.
-      GetCpuResource() = std::make_shared<Tensor<T>>(adapter_.get(), shape_);
-      return GetCpuResource()->GetValue();
+      GetCpuResource() = std::make_shared<Tensor<T>>(shape_);
+      return CreateTensorValueFromExternalBuffer(engine, out);
     } else {
       // Try to allocate the backing memory for the caller
       auto bufferSize = std::accumulate(std::begin(shape_), std::end(shape_), static_cast<int64_t>(1), std::multiplies<int64_t>());
@@ -178,21 +170,21 @@ struct TensorBase : TBase {
           D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
           D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS};

-      auto spSession = context.session.as<winmlp::LearningModelSession>();
-      auto spDevice = spSession->Device().as<winmlp::LearningModelDevice>();
+      auto device = session->Device().as<winmlp::LearningModelDevice>();

-      winrt::com_ptr<ID3D12Resource> pGPUResource = nullptr;
-      spDevice->GetD3DDevice()->CreateCommittedResource(
+      winrt::com_ptr<ID3D12Resource> gpu_resource = nullptr;
+      device->GetD3DDevice()->CreateCommittedResource(
          &heapProperties,
          D3D12_HEAP_FLAG_NONE,
          &resourceDesc,
          D3D12_RESOURCE_STATE_COMMON,
          nullptr,
          __uuidof(ID3D12Resource),
-         pGPUResource.put_void());
+         gpu_resource.put_void());
+
+      GetGpuResource() = gpu_resource;

-      GetGpuResource() = std::make_shared<DMLResource>(pGPUResource.get(), resourceDesc.Width);
-      return CreateGPUMLValue(GetGpuResource(), context);
+      return CreateGPUMLValue(GetGpuResource().get(), context, out);
     }
   }

@@ -207,9 +199,8 @@ struct TensorBase : TBase {
   }

   // ILotusValueProviderPrivate::GetOrtValue
-  STDMETHOD(GetOrtValue)
-  (WinML::BindingContext& context, OrtValue** ort_value, OrtAllocator** ort_allocator) {
-    ORT_UNUSED_PARAMETER(ort_allocator);
+  STDMETHOD(GetValue)
+  (WinML::BindingContext& context, IValue** out) {
     RETURN_HR_IF_NULL_MSG(
         WINML_ERR_INVALID_BINDING,
         m_resources,
         "The tensor has been closed and its resources have been detached!");

     auto spSession = context.session.as<winmlp::LearningModelSession>();
     auto spDevice = spSession->Device().as<winmlp::LearningModelDevice>();
+
     if (spDevice->IsCpuDevice()) {
-      *ort_value = CPUTensorize(context).release();
+      RETURN_IF_FAILED(CPUTensorize(context, out));
     } else {
-      *ort_value = GPUTensorize(context).release();
+      RETURN_IF_FAILED(GPUTensorize(context, out));
     }

     return S_OK;
@@ -240,47 +232,88 @@ struct TensorBase : TBase {
     return size;
   }

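The ASSERT_TEMPLATE_PARAMETERS checks referenced in the helpers below are defined elsewhere in TensorBase.h and are not shown in this patch; they are assumed to reduce to a compile-time type check roughly like this sketch:

    // Hypothetical equivalent; the name and signature are assumptions, not part of this patch.
    // Requires <type_traits>.
    template <typename ElementType, typename ExpectedType>
    constexpr void AssertTemplateParameters() {
      static_assert(std::is_same_v<ElementType, ExpectedType>,
                    "This API can only be called with template parameters matching the tensor's element type.");
    }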
+  template <typename ElementType = T>
+  void SetBufferFromValueResourceBuffer(uint32_t size, void* data) {
+    // This adds compile time checks that ensure that the API can only be called when
+    // the conditions of ASSERT_TEMPLATE_PARAMETERS() are met.
+    ASSERT_TEMPLATE_PARAMETERS();
+
+    GetCpuResource()->set(size, reinterpret_cast<ElementType*>(data));
+  }
+
+  template <>
+  void SetBufferFromValueResourceBuffer<std::string>(uint32_t size, void* data) {
+    // Ensure that this call is being called with the correct template parameters
+    ASSERT_TEMPLATE_PARAMETERS();
+
+    GetCpuResource()->get_tensor_buffer()->Set(size, reinterpret_cast<std::string_view*>(data));
+  }
+
+  template <typename ElementType = T>
+  HRESULT CreateTensorValueFromExternalBuffer(WinML::IEngine* engine, IValue** value) {
+    // This adds compile time checks that ensure that the API can only be called when
+    // the conditions of ASSERT_TEMPLATE_PARAMETERS() are met.
+    ASSERT_TEMPLATE_PARAMETERS();
+
+    RETURN_IF_FAILED_MSG(engine->CreateTensorValueFromExternalBuffer(
+                             GetCpuResource()->buffer().second, GetCpuResource()->size_in_bytes(), GetCpuResource()->shape().data(),
+                             GetCpuResource()->shape().size(), TensorKind(), value),
+                         "Failed to prepare buffer for copy back from device resource.");
+    return S_OK;
+  }
+
+  template <>
+  HRESULT CreateTensorValueFromExternalBuffer<std::string>(WinML::IEngine* engine, IValue** value) {
+    // Ensure that this call is being called with the correct template parameters
+    ASSERT_TEMPLATE_PARAMETERS();
+
+    std::vector<const char*> raw_values;
+    auto string_array = GetCpuResource()->buffer().second;
+    std::transform(
+        string_array,
+        string_array + GetCpuResource()->size_in_bytes(),
+        std::back_inserter(raw_values),
+        [&](auto& str) { return str.c_str(); });
+
+    RETURN_IF_FAILED_MSG(engine->CreateStringTensorValueFromDataWithCopy(
+                             raw_values.data(), raw_values.size(), GetCpuResource()->shape().data(),
+                             GetCpuResource()->shape().size(), value),
+                         "Failed to prepare buffer for copy back from device resource.");
+    return S_OK;
+  }

   // ILotusValueProviderPrivate::UpdateSourceResourceData
   STDMETHOD(UpdateSourceResourceData)
-  (BindingContext& context, OrtValue* ort_value) {
+  (BindingContext& context, IValue* value) {
     RETURN_HR_IF_NULL_MSG(
         E_ILLEGAL_METHOD_CALL,
         m_resources,
         "The tensor has been closed and its resources have been detached during evaluation!");

-    // get the mutable raw data buffer
-    void* pResource = nullptr;
-    Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(ort_value, &pResource));
+    WinML::Resource updated_resource;
+    RETURN_IF_FAILED(value->GetResource(updated_resource));

     // get the shape
-    Ort::TensorTypeAndShapeInfo type_and_shape(nullptr);
-    Ort::ThrowOnError(Ort::GetApi().GetTensorTypeAndShape(ort_value, type_and_shape.put()));
-    shape_ = type_and_shape.GetShape();
+    RETURN_IF_FAILED_MSG(value->GetTensorShape(shape_), "Failed to get the tensor shape from resource!");

     // make sure we always have a CPU resource
     if (GetCpuResource() == nullptr) {
-      GetCpuResource() = std::make_shared<Tensor<T>>(adapter_.get(), shape_);
+      GetCpuResource() = std::make_shared<Tensor<T>>(shape_);
     }

-    // get the memory info for the ort value
-    Ort::MemoryInfo memory_info(nullptr);
-    RETURN_IF_FAILED(adapter_->GetValueMemoryInfo(ort_value, memory_info.put()));
-
-    // is it from the CPU provider?
-    if (!strcmp(memory_info.Name(), onnxruntime::CPU) ||
-        memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUOutput ||
-        memory_info.MemType() == ::OrtMemType::OrtMemTypeCPUInput) {
+    bool is_cpu;
+    if (SUCCEEDED(value->IsCpu(&is_cpu)) && is_cpu) {
       // Get the data pointer and size
-      T* pData;
-      uint32_t pSize;
-      std::tie(pSize, pData) = GetCpuResource()->buffer();
+      T* data;
+      uint32_t size;
+      std::tie(size, data) = GetCpuResource()->buffer();

-      if (pResource != reinterpret_cast<void*>(pData)) {
+      if (updated_resource.get() != reinterpret_cast<void*>(data)) {
         // Only copy the data if the source and destination are not the same!
         // The engine provided buffer will not match the tensor buffer when
         // the tensor is created as a placeholder output, or as an unbound output.
-        GetCpuResource()->set(static_cast<uint32_t>(ShapeSize(shape_)), reinterpret_cast<T*>(pResource));
+        auto shape_size = static_cast<uint32_t>(ShapeSize(shape_));
+        SetBufferFromValueResourceBuffer(shape_size, updated_resource.get());
       }
     } else {
       // If we got a gpu resource, we should move the data to the cpu so accessors can retrieve the data.
       // We don't need to copy the engine provided dx resource into a local copy since we always preallocate gpu
       // resources for tensors. Therefore we are certain that the returned dxresource is the same as the one we passed in
       // and was updated in place.
       auto spSession = context.session.as<winmlp::LearningModelSession>();
-      auto cpuValue = GetCpuResource()->GetValue();
-      RETURN_IF_FAILED(adapter_->CopyTensor(spSession->GetExecutionProvider(), ort_value, cpuValue));
+      auto engine = spSession->GetEngine();
+
+      winrt::com_ptr<IValue> dest;
+      RETURN_IF_FAILED_MSG(CreateTensorValueFromExternalBuffer(engine, dest.put()),
+                           "Failed to prepare buffer for copy back from device resource.");
+      RETURN_IF_FAILED(engine->CopyValueAcrossDevices(value, dest.get()));
     }

     return S_OK;
@@ -377,7 +414,7 @@ struct TensorBase : TBase {
     typename TBase::class_type tensorValue = winrt::make<TDerived>();
     auto tensorValueImpl = tensorValue.as<TDerived>();
     tensorValueImpl->shape_ = vecShape;
-    tensorValueImpl->GetCpuResource() = std::make_shared<Tensor<T>>(tensorValueImpl->adapter_.get(), vecShape, buffer);
+    tensorValueImpl->GetCpuResource() = std::make_shared<Tensor<T>>(vecShape, buffer);
     return tensorValue;
   }
   WINML_CATCH_ALL
@@ -410,7 +447,7 @@ struct TensorBase : TBase {
     THROW_HR_IF(E_INVALIDARG, desc.Width < width);

     // make the underlying winrt object
-    typename TBase::class_type tensorValue = winrt::make<TDerived>(shapeVector, value, desc.Width);
+    typename TBase::class_type tensorValue = winrt::make<TDerived>(shapeVector, value);

     // return it (the caller owns the ref)
     *result = tensorValue.as<IUnknown>().detach();
@@ -496,7 +533,7 @@ struct TensorBase : TBase {
     // This Api is not supported for TensorString
     RETURN_HR_IF_MSG(
         ERROR_INVALID_FUNCTION,
-        (std::is_same<T, std::string>::value),
+        (std::is_same_v<T, std::string>),
         "TensorString objects cannot return byte buffers!");

     RETURN_HR_IF_NULL_MSG(
@@ -518,7 +555,7 @@ struct TensorBase : TBase {
         m_resources,
         "The tensor has been closed and its resources have been detached!");

-    GetGpuResource()->DXResource.copy_to(ppResource);
+    GetGpuResource().copy_to(ppResource);
     return S_OK;
   }
   WINML_CATCH_ALL_COM
@@ -551,12 +588,12 @@ struct TensorBase : TBase {

   // Specialized version to convert float16 to float
   template <>
-  winrt::Windows::Foundation::Collections::IVectorView<float> GetAsVectorView<onnxruntime::MLFloat16>() try {
+  winrt::Windows::Foundation::Collections::IVectorView<float> GetAsVectorView<WinML::Half>() try {
     // Ensure that this call is being called with the correct template parameters
-    ASSERT_TEMPLATE_PARAMETERS<onnxruntime::MLFloat16>();
+    ASSERT_TEMPLATE_PARAMETERS<WinML::Half>();

     uint32_t size;
-    onnxruntime::MLFloat16* pBuffer;
+    WinML::Half* pBuffer;

     // Get the data pointer and size
     std::tie(size, pBuffer) = GetCpuResource()->buffer();
@@ -567,7 +604,7 @@ struct TensorBase : TBase {
         floatValue.data(),
         sizeof(float) /* output stride */,
         reinterpret_cast<DirectX::PackedVector::HALF*>(pBuffer),
-        sizeof(DirectX::PackedVector::HALF) /* input stride */,
+        sizeof(WinML::Half) /* input stride */,
         size);

     // Create IVectorView from copied data.
@@ -684,12 +721,12 @@ struct TensorBase : TBase {

   // Specialized version to convert floats to float16
   template <>
-  void SetBufferFromArray<onnxruntime::MLFloat16>(winrt::array_view<float const> data) {
+  void SetBufferFromArray<WinML::Half>(winrt::array_view<float const> data) {
     // Ensure that this call is being called with the correct template parameters
-    ASSERT_TEMPLATE_PARAMETERS<onnxruntime::MLFloat16>();
+    ASSERT_TEMPLATE_PARAMETERS<WinML::Half>();

     uint32_t size;
-    onnxruntime::MLFloat16* pBuffer;
+    WinML::Half* pBuffer;

     // Get the data pointer and size
     std::tie(size, pBuffer) = GetCpuResource()->buffer();

     THROW_HR_IF(E_UNEXPECTED, data.size() != size);
     DirectX::PackedVector::XMConvertFloatToHalfStream(
         reinterpret_cast<DirectX::PackedVector::HALF*>(pBuffer),
-        sizeof(DirectX::PackedVector::HALF) /* output stride */,
+        sizeof(WinML::Half) /* output stride */,
         data.data(),
         sizeof(float) /* input stride */,
         data.size());
@@ -760,13 +797,13 @@ struct TensorBase : TBase {

   // Specialized version to convert floats to float16
   template <>
-  void SetBufferFromIterable<onnxruntime::MLFloat16>(
+  void SetBufferFromIterable<WinML::Half>(
       winrt::Windows::Foundation::Collections::IIterable<float> const& data) {
     // Ensure that this call is being called with the correct template parameters
-    ASSERT_TEMPLATE_PARAMETERS<onnxruntime::MLFloat16>();
+    ASSERT_TEMPLATE_PARAMETERS<WinML::Half>();

     uint32_t size;
-    onnxruntime::MLFloat16* pBuffer;
+    WinML::Half* pBuffer;

     // Get the data pointer and size
     std::tie(size, pBuffer) = GetCpuResource()->buffer();
@@ -826,7 +863,7 @@ struct TensorBase : TBase {
     return m_resources->CpuResource;
   }

-  std::shared_ptr<DMLResource>& GetGpuResource() {
+  winrt::com_ptr<ID3D12Resource>& GetGpuResource() {
     WINML_THROW_HR_IF_NULL_MSG(
         E_ILLEGAL_METHOD_CALL,
         m_resources,
         "The tensor has been closed and its resources have been detached!");

     return m_resources->GpuResource;
   }

 private:
   std::shared_ptr<TensorResources<T>> m_resources;
   std::vector<winrt::weak_ref<TensorMemoryBufferReference<T>>> m_outstandingReferences;
   bool m_isClosed = false;
-  winrt::com_ptr<winmla::IWinMLAdapter> adapter_;
 };

 }  // namespace Windows::AI::MachineLearning
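To summarize the UpdateSourceResourceData flow above: when the engine's value already lives on the CPU, the data is copied in place only if the buffers differ; otherwise the tensor's CPU buffer is wrapped as a destination value and the engine performs the cross-device copy. A hedged sketch, with `value` standing in for the engine-produced output:

    // Sketch only; condenses the logic in TensorBase::UpdateSourceResourceData.
    bool is_cpu = false;
    if (SUCCEEDED(value->IsCpu(&is_cpu)) && is_cpu) {
      // CPU case: copy only if the engine buffer is not already the tensor's own buffer.
    } else {
      winrt::com_ptr<WinML::IValue> dest;
      // Wrap the tensor's CPU buffer via CreateTensorValueFromExternalBuffer, then:
      engine->CopyValueAcrossDevices(value, dest.get());  // GPU -> CPU copy-back
    }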
diff --git a/winml/lib/Api/impl/TensorBuffer.h b/winml/lib/Api/impl/TensorBuffer.h
index 079175fca2a27..d43b61d7cb25a 100644
--- a/winml/lib/Api/impl/TensorBuffer.h
+++ b/winml/lib/Api/impl/TensorBuffer.h
@@ -133,9 +133,7 @@ class TensorBuffer {
     return std::make_pair(gsl::narrow_cast<uint32_t>(m_buffer.size()), m_buffer.data());
   }

-  // The Set APIs should generally be avoided in the TensorBuffer.
-  // Callers should generally use the Buffer API and copy directly into it.
-  auto Set(uint32_t size, const std::string* pData) {
+  auto Set(uint32_t size, std::string_view* data) {
     WINML_THROW_HR_IF_FALSE_MSG(
         E_INVALIDARG,
         size <= m_buffer.size(),
         "Argument size (%d) exceeds the tensor size (%d).",
         static_cast<int>(size),
         static_cast<int>(m_buffer.size()));

-    std::copy(pData, pData + size, m_buffer.begin());
-  }
-
-  auto Set(std::vector<std::string>&& other) {
-    auto tensorSize = m_buffer.size();
-
-    WINML_THROW_HR_IF_FALSE_MSG(
-        E_INVALIDARG,
-        other.size() <= tensorSize,
-        "Vector argument other has size (%d) which is greater than tensor size(%d)",
-        static_cast<int>(other.size()),
-        static_cast<int>(tensorSize));
-
-    if (tensorSize != other.size()) {
-      other.resize(tensorSize);
-    }
-
-    m_buffer = std::move(other);
+    // Copy
+    std::copy(data, data + size, m_buffer.begin());
   }
 };
 }  // namespace Windows::AI::MachineLearning
\ No newline at end of file
diff --git a/winml/lib/Api/impl/TensorKindFrom.h b/winml/lib/Api/impl/TensorKindFrom.h
index 0a3c8a6a7218e..e9662727f53eb 100644
--- a/winml/lib/Api/impl/TensorKindFrom.h
+++ b/winml/lib/Api/impl/TensorKindFrom.h
@@ -4,6 +4,13 @@
 #pragma once

 namespace Windows::AI::MachineLearning {
+
+// We need to define our own type for Half since DirectX::PackedVector::HALF resolves to uint16_t per its typedef declaration.
+// Templates require an actual type name to resolve correctly.
+struct Half {
+  DirectX::PackedVector::HALF value;
+};
+
 template <typename T>
 struct TensorKindFrom {};
 template <>
@@ -60,12 +67,7 @@ struct TensorKindFrom { static const winml::TensorKind Type = wi
 template <>
 struct TensorKindFrom<winrt::hstring> { static const winml::TensorKind Type = winml::TensorKind::String; };
 template <>
-struct TensorKindFrom<onnxruntime::MLFloat16> { static const winml::TensorKind Type = winml::TensorKind::Float16; };
-
-template <typename T>
-struct ONNXTensorElementDataTypeFrom {};
-
-
+struct TensorKindFrom<WinML::Half> { static const winml::TensorKind Type = winml::TensorKind::Float16; };

 template <typename T>
 struct TensorFeatureDescriptorFrom {
@@ -75,9 +77,9 @@ struct TensorFeatureDescriptorFrom {
     return winrt::make<winmlp::TensorFeatureDescriptor>(
         nullptr /* set to null as values are name-less */,
         nullptr /* set to null as values are description-less */,
-        false /* set to false as values dont have required annotations */,
         TensorKindFrom<T>::Type,
         shape,
+        false /* set to false as values dont have required annotations */,
         false /* set to false as this is not a tensor of unsupported metadata */);
   }
 };
diff --git a/winml/lib/Api/impl/TensorMemoryBufferReference.h b/winml/lib/Api/impl/TensorMemoryBufferReference.h
index bc6234c4e8741..f5df6c47c68c0 100644
--- a/winml/lib/Api/impl/TensorMemoryBufferReference.h
+++ b/winml/lib/Api/impl/TensorMemoryBufferReference.h
@@ -9,24 +9,6 @@
 #include

 namespace Windows::AI::MachineLearning {
-struct DMLResource {
-  DMLResource(ID3D12Resource* pResource, UINT64 resource_width) {
-    DXResource.copy_from(pResource);
-    WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put()));
-    ExecutionProviderAllocatedResource = adapter_->CreateGPUAllocationFromD3DResource(pResource);
-    resource_width_ = resource_width;
-  }
-
-  ~DMLResource() {
-    adapter_->FreeGPUAllocation(ExecutionProviderAllocatedResource);
-  }
-
-  winrt::com_ptr<ID3D12Resource> DXResource;
-  UINT64 resource_width_;
-  void* ExecutionProviderAllocatedResource = nullptr;
-  winrt::com_ptr<winmla::IWinMLAdapter> adapter_;
-};
-
 template <typename T>
 struct TensorResources {
   // ITensorNative::GetBuffer
   HRESULT GetBuffer(std::vector<int64_t> shape, BYTE** value, UINT32* capacity) {
     RETURN_HR_IF_NULL(E_POINTER, value);
     RETURN_HR_IF_NULL(E_POINTER, capacity);

-    *value = nullptr;
-    *capacity = 0;
-
-    // This Api is not supported for TensorString
-    auto isTensorString = std::is_same<T, std::string>::value;
-    RETURN_HR_IF(ERROR_INVALID_FUNCTION, isTensorString);
+    RETURN_HR_IF_MSG(
+        ERROR_INVALID_FUNCTION,
+        (std::is_same_v<T, std::string>),
+        "TensorString objects cannot return byte buffers!");

     try {
+      *value = nullptr;
+      *capacity = 0;
+
       // Lazily allocate the cpu resource on call to GetBuffer
       if (CpuResource == nullptr) {
-        winrt::com_ptr<winmla::IWinMLAdapter> adapter;
-        WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
-        CpuResource = std::make_shared<Tensor<T>>(adapter.get(), shape);
+        CpuResource = std::make_shared<Tensor<T>>(shape);
       }

-      if constexpr (std::is_same_v<T, std::string>) {
-        std::string* pData;
-        uint32_t pSize;
-        std::tie(pSize, pData) = CpuResource->buffer();
-
-        // Set out parameters
-        *capacity = static_cast<UINT32>(pSize * sizeof(T));
-        *value = (BYTE*)pData;
-      } else {
-        // Get the data pointer and size
-        T* pData;
-        uint32_t pSize;
-        std::tie(pSize, pData) = CpuResource->buffer();
-
-        // Set out parameters
-        *capacity = static_cast<UINT32>(pSize * sizeof(T));
-        *value = (BYTE*)pData;
-      }
+      // Get the data pointer and size
+      T* data;
+      uint32_t size;
+      std::tie(size, data) = CpuResource->buffer();

+      // Set out parameters
+      *capacity = static_cast<UINT32>(size * sizeof(T));
+      *value = (BYTE*)data;
       return S_OK;
     }
     WINML_CATCH_ALL_COM
@@ -77,7 +47,7 @@ struct TensorResources {

   // These are accessed directly by TensorMemoryBufferReference and TensorBase
   std::shared_ptr<Tensor<T>> CpuResource;
-  std::shared_ptr<DMLResource> GpuResource;
+  winrt::com_ptr<ID3D12Resource> GpuResource;
 };

 // This class holds onto the lifetime of TensorResources so that they can be kept alive by TensorBase AND its active MBRs.
diff --git a/winml/lib/Api/inc/ILotusValueProviderPrivate.h b/winml/lib/Api/inc/ILotusValueProviderPrivate.h
index 5ae5adc902a67..3bfc6a2a79961 100644
--- a/winml/lib/Api/inc/ILotusValueProviderPrivate.h
+++ b/winml/lib/Api/inc/ILotusValueProviderPrivate.h
@@ -3,7 +3,7 @@

 #pragma once

-#include "WinMLAdapter.h"
+#include "iengine.h"

 // ILotusValueProviderPrivate exposes a private Lotus interface to the engine so that it can retrieve tensor
 // resources stored in winrt structures.
@@ -24,9 +24,9 @@ struct BindingContext {
 };

 struct __declspec(uuid("27e2f437-0112-4693-849e-e04323a620fb")) __declspec(novtable) ILotusValueProviderPrivate : IUnknown {
-  virtual HRESULT __stdcall GetOrtValue(BindingContext& binding_context, OrtValue** ort_value, OrtAllocator** ort_allocator) = 0;
+  virtual HRESULT __stdcall GetValue(BindingContext& binding_context, WinML::IValue** out) = 0;
   virtual HRESULT __stdcall IsPlaceholder(bool* is_placeholder) = 0;
-  virtual HRESULT __stdcall UpdateSourceResourceData(BindingContext& binding_context, OrtValue* ort_value) = 0;
+  virtual HRESULT __stdcall UpdateSourceResourceData(BindingContext& binding_context, WinML::IValue* value) = 0;
   virtual HRESULT __stdcall AbiRepresentation(winrt::Windows::Foundation::IInspectable& abi_representation) = 0;
 };
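The renamed interface above now expresses the engine-facing contract in terms of WinML::IValue. A rough sketch of how a binder drives it (the provider, context, and value objects here are hypothetical):

    // Input binding: the provider materializes an IValue for the engine.
    winrt::com_ptr<WinML::IValue> input_value;
    provider->GetValue(binding_context, input_value.put());

    // Output binding: after IEngine::Run, the engine-owned value is written
    // back into the WinRT object held by the caller.
    provider->UpdateSourceResourceData(binding_context, output_value.get());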
diff --git a/winml/lib/Common/inc/PheonixSingleton.h b/winml/lib/Common/inc/PheonixSingleton.h
index c3ab8edd821cc..0ab0f21f4cad5 100644
--- a/winml/lib/Common/inc/PheonixSingleton.h
+++ b/winml/lib/Common/inc/PheonixSingleton.h
@@ -3,8 +3,8 @@

 #pragma once

-template <typename T>
-std::shared_ptr<T> PheonixSingleton() {
+template <typename T, typename... TArgs>
+std::shared_ptr<T> PheonixSingleton(TArgs&&... args) {
   static std::weak_ptr<T> instance_;
   static std::mutex lock_;

@@ -13,7 +13,7 @@ std::shared_ptr<T> PheonixSingleton() {
     return instance;
   }

-  auto instance = std::make_shared<T>();
+  auto instance = std::make_shared<T>(std::forward<TArgs>(args)...);
   instance_ = instance;
   return instance;
 }
\ No newline at end of file
diff --git a/winml/lib/Common/inc/iengine.h b/winml/lib/Common/inc/iengine.h
new file mode 100644
index 0000000000000..f9c9dd503dc40
--- /dev/null
+++ b/winml/lib/Common/inc/iengine.h
@@ -0,0 +1,181 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace Windows::AI::MachineLearning {
+
+MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371")
+IModelInfo : IUnknown {
+  STDMETHOD(GetAuthor)
+  (const char** out, size_t* len) PURE;
+
+  STDMETHOD(GetName)
+  (const char** out, size_t* len) PURE;
+
+  STDMETHOD(GetDomain)
+  (const char** out, size_t* len) PURE;
+
+  STDMETHOD(GetDescription)
+  (const char** out, size_t* len) PURE;
+
+  STDMETHOD(GetVersion)
+  (int64_t * out) PURE;
+
+  STDMETHOD(GetModelMetadata)
+  (ABI::Windows::Foundation::Collections::IMapView<HSTRING, HSTRING> * *metadata) PURE;
+
+  STDMETHOD(GetInputFeatures)
+  (ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor> * *features) PURE;
+
+  STDMETHOD(GetOutputFeatures)
+  (ABI::Windows::Foundation::Collections::IVectorView<winml::ILearningModelFeatureDescriptor> * *features) PURE;
+};
+
+MIDL_INTERFACE("1b198b76-5c44-480d-837c-8433ca6eaf99")
+IModel : IUnknown {
+  STDMETHOD(GetModelInfo)
+  (IModelInfo * *info) PURE;
+
+  STDMETHOD(ModelEnsureNoFloat16)
+  () PURE;
+
+  STDMETHOD(CloneModel)
+  (IModel * *copy) PURE;
+};
+
+using Resource = std::unique_ptr<void, std::function<void(void*)>>;
+MIDL_INTERFACE("31f39226-cfe8-4758-af38-3d01b2a33ee1")
+IValue : IUnknown {
+  STDMETHOD(IsEmpty)
+  (bool* out) PURE;
+
+  STDMETHOD(IsCpu)
+  (bool* out) PURE;
+
+  STDMETHOD(GetResource)
+  (WinML::Resource & resource) PURE;
+
+  STDMETHOD(IsTensor)
+  (bool* out) PURE;
+
+  STDMETHOD(IsOfTensorType)
+  (winml::TensorKind kind, bool* out) PURE;
+
+  STDMETHOD(GetTensorShape)
+  (std::vector<int64_t> & shape_vector) PURE;
+
+  STDMETHOD(IsOfMapType)
+  (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE;
+
+  STDMETHOD(IsOfVectorMapType)
+  (winml::TensorKind key_kind, winml::TensorKind value_kind, bool* out) PURE;
+};
+
+MIDL_INTERFACE("30c99886-38d2-41cb-a615-203fe7d7daac")
+IEngine : IUnknown {
+  STDMETHOD(LoadModel)
+  (_In_ IModel*) PURE;
+
+  STDMETHOD(Initialize)
+  () PURE;
+
+  STDMETHOD(RegisterGraphTransformers)
+  () PURE;
+
+  STDMETHOD(RegisterCustomRegistry)
+  (IMLOperatorRegistry * registry) PURE;
+
+  STDMETHOD(EndProfiling)
+  () PURE;
+
+  STDMETHOD(StartProfiling)
+  () PURE;
+
+  STDMETHOD(FlushContext)
+  () PURE;
+
+  STDMETHOD(TrimUploadHeap)
+  () PURE;
+
+  STDMETHOD(ReleaseCompletedReferences)
+  () PURE;
+
+  STDMETHOD(Sync)
+  () PURE;
+
+  STDMETHOD(CreateTensorValue)
+  (const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE;
+
+  STDMETHOD(CreateTensorValueFromExternalD3DResource)
+  (ID3D12Resource * resource, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE;
+
+  STDMETHOD(CreateTensorValueFromExternalBuffer)
+  (void* data, size_t size_in_bytes, const int64_t* shape, size_t count, winml::TensorKind kind, _Out_ IValue** out) PURE;
+
+  STDMETHOD(CreateStringTensorValueFromDataWithCopy)
+  (const char* const* data, size_t num_elements, const int64_t* shape, size_t count, _Out_ IValue** out) PURE;
+
+  STDMETHOD(CreateNullValue)
+  (_Out_ IValue * *out) PURE;
+
+  STDMETHOD(CreateMapValue)
+  (IInspectable * map, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue * *out) PURE;
+
+  STDMETHOD(CreateSequenceOfMapsValue)
+  (IInspectable * sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, _Out_ IValue * *out) PURE;
+
+  STDMETHOD(CreateOneInputAcrossDevices)
+  (const char* name, IValue* src, IValue** dest) PURE;
+
+  STDMETHOD(CopyValueAcrossDevices)
+  (IValue * src, IValue * dest) PURE;
+
+  STDMETHOD(Run)
+  (const char** input_names, IValue** inputs, size_t num_inputs, const char** output_names, IValue** outputs, size_t num_outputs) PURE;
+
+  STDMETHOD(FillFromMapValue)
+  (IInspectable * map, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue * value) PURE;
+
+  STDMETHOD(FillSequenceOfMapsValue)
+  (IInspectable * sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue * value) PURE;
+};
+
+MIDL_INTERFACE("0452ef15-b66b-47ca-9eff-aedac571764e")
+IEngineBuilder : IUnknown {
+  STDMETHOD(SetD3D12Resources)
+  (ID3D12Device * device, ID3D12CommandQueue * queue) PURE;
+
+  STDMETHOD(GetD3D12Device)
+  (ID3D12Device * *device) PURE;
+
+  STDMETHOD(GetID3D12CommandQueue)
+  (ID3D12CommandQueue * *queue) PURE;
+
+  STDMETHOD(SetBatchSizeOverride)
+  (uint32_t batch_size_override) PURE;
+
+  STDMETHOD(CreateEngine)
+  (IEngine * *out) PURE;
+};
+
+MIDL_INTERFACE("5eddd25a-70ad-46ef-a445-78fbaf792c2f")
+IEngineFactory : IUnknown {
+  STDMETHOD(CreateModel)
+  (_In_ const char* model_path, _In_ size_t len, _Outptr_ IModel** out) PURE;
+
+  STDMETHOD(CreateModel)
+  (_In_ void* data, _In_ size_t size, _Outptr_ IModel** out) PURE;
+
+  STDMETHOD(CreateEngineBuilder)
+  (IEngineBuilder * *engine_builder) PURE;
+
+  STDMETHOD(EnableDebugOutput)
+  (bool is_enabled) PURE;
+
+  STDMETHOD(CreateCustomRegistry)
+  (_Out_ IMLOperatorRegistry * *registry) PURE;
+};
+
+}  // namespace Windows::AI::MachineLearning
\ No newline at end of file
diff --git a/winml/lib/Common/inc/onnx.h b/winml/lib/Common/inc/onnx.h
index d537ba450b9d8..0db64211dc8e0 100644
--- a/winml/lib/Common/inc/onnx.h
+++ b/winml/lib/Common/inc/onnx.h
@@ -13,9 +13,6 @@
 // Restore ERROR define
 #define ERROR 0

-// the C++ ort api
-#include "core/session/onnxruntime_cxx_api.h"
-
 #ifdef USE_DML
 #include <DirectML.h>
 #endif  // USE_DML
diff --git a/winml/test/api/LearningModelBindingAPITest.cpp b/winml/test/api/LearningModelBindingAPITest.cpp
index d67013dfc1db4..2f322987f9be7 100644
--- a/winml/test/api/LearningModelBindingAPITest.cpp
+++ b/winml/test/api/LearningModelBindingAPITest.cpp
@@ -154,6 +154,8 @@ static void DictionaryVectorizerMapString()
     WINML_EXPECT_TRUE(first.Current().Key() == mapInputName);
     WINML_EXPECT_TRUE(first.Current().Value() == mapInputInspectable);
     WINML_EXPECT_TRUE(binding.Lookup(mapInputName) == mapInputInspectable);
+
+    modelSession.Evaluate(binding, L"");
 }

 static void RunZipMapInt64(