0. In the future, model_names and versions will be verified.
-## Request and Response Payload
+### Request and Response Payload
-An HTTP request can be a Protobuf message in two formats: binary or JSON. The HTTP request header field `Content-Type` tells the server how to handle the request and thus it is mandatory for all requests. Requests missing `Content-Type` will be rejected as `400 Bad Request`.
+The request and response need to be a protobuf message. The Protobuf definition can be found [here](https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/server/protobuf/predict.proto).
+
+A protobuf message can be in one of two formats: binary or JSON. The binary payload usually has lower latency, while the JSON format is easier for humans to read.
+
+The HTTP request header field `Content-Type` tells the server how to handle the request and thus it is mandatory for all requests. Requests missing `Content-Type` will be rejected as `400 Bad Request`.
* For `"Content-Type: application/json"`, the payload will be deserialized as JSON string in UTF-8 format
* For `"Content-Type: application/vnd.google.protobuf"`, `"Content-Type: application/x-protobuf"` or `"Content-Type: application/octet-stream"`, the payload will be consumed as protobuf message directly.
-The Protobuf definition can be found [here](https://github.com/Microsoft/onnxruntime/blob/master/onnxruntime/server/protobuf/predict.proto).
+Clients can control the response type by setting the `Accept` header field in the request, and the server will serialize the response in the desired format. The available choices are the same as for the `Content-Type` header field. If this field is not set in the request, the server will respond with the same type as the request.
-## Inferencing
+### Inferencing
To send a request to the server, you can use any tool which supports making HTTP requests. Here is an example using `curl`:
@@ -60,11 +68,17 @@ or
curl -X POST --data-binary "@predict_request_0.pb" -H "Content-Type: application/octet-stream" -H "Foo: 1234" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
```
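+
+The same call can also be made from Python, for example with the `requests` package. The snippet below is a minimal sketch that assumes the server, model, and `predict_request_0.pb` file from the `curl` example above, and it also sets the `Accept` header described earlier to get a JSON response back:
+
+```
+import requests
+
+# Read a serialized PredictRequest protobuf message (see predict.proto).
+with open("predict_request_0.pb", "rb") as f:
+    payload = f.read()
+
+response = requests.post(
+    "http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict",
+    data=payload,
+    headers={
+        "Content-Type": "application/octet-stream",  # body is a binary protobuf message
+        "Accept": "application/json",                # ask the server to respond in JSON
+    },
+)
+print(response.status_code)
+print(response.json())
+```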
-Clients can control the response type by setting the request with an `Accept` header field and the server will serialize in your desired format. The choices currently available are the same as the `Content-Type` header field.
+### Interactive tutorial notebook
+
+A simple Jupyter notebook demonstrating the usage of ONNX Runtime server to host an ONNX model and perform inferencing can be found [here](https://github.com/onnx/tutorials/blob/master/tutorials/OnnxRuntimeServerSSDModel.ipynb).
+
+## GRPC Endpoint
+
+If you prefer using the GRPC endpoint, the protobuf definition can be found [here](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/server/protobuf/prediction_service.proto). You can generate a client from it and make GRPC calls to the server. To learn more about generating the client code and calling the server, please refer to [the GRPC tutorials](https://grpc.io/docs/tutorials/).
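+
+As a rough sketch, a Python client generated with `grpcio-tools` could be used as shown below. The module, stub, and method names (`predict_pb2`, `prediction_service_pb2_grpc`, `PredictionServiceStub`, `Predict`) and the port are assumptions based on the .proto files; verify them against the code you generate and your server configuration:
+
+```
+import grpc
+
+# Modules assumed to be generated with:
+#   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. predict.proto prediction_service.proto
+import predict_pb2
+import prediction_service_pb2_grpc
+
+channel = grpc.insecure_channel("127.0.0.1:50051")  # assumed GRPC host:port
+stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
+
+request = predict_pb2.PredictRequest()
+# ... fill request.inputs with named tensors as defined in predict.proto ...
+response = stub.Predict(request)
+print(response)
+```
+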
## Advanced Topics
-### Number of HTTP Threads
+### Number of Worker Threads
You can change this to optimize server utilization. The default is the number of CPU cores on the host machine.
@@ -75,66 +89,11 @@ For easy tracking of requests, we provide the following header fields:
* `x-ms-request-id`: will be in the response header, no matter the request result. It will be a GUID/uuid with dash, e.g. `72b68108-18a4-493c-ac75-d0abd82f0a11`. If the request headers contain this field, the value will be ignored.
* `x-ms-client-request-id`: a field for clients to track their requests. The content will persist in the response headers.
-Here is an example of a client sending a request:
-
-#### Client Side
+### rsyslog Support
-```
-$ curl -v -X POST --data-binary "@predict_request_0.pb" -H "Content-Type: application/octet-stream" -H "Foo: 1234" -H "x-ms-client-request-id: my-request-001" -H "Accept: application/json" http://127.0.0.1:8001/v1/models/mymodel/versions/3:predict
-Note: Unnecessary use of -X or --request, POST is already inferred.
-* Trying 127.0.0.1...
-* Connected to 127.0.0.1 (127.0.0.1) port 8001 (#0)
-> POST /v1/models/mymodel/versions/3:predict HTTP/1.1
-> Host: 127.0.0.1:8001
-> User-Agent: curl/7.47.0
-> Content-Type: application/octet-stream
-> x-ms-client-request-id: my-request-001
-> Accept: application/json
-> Content-Length: 3179
-> Expect: 100-continue
->
-* Done waiting for 100-continue
-* We are completely uploaded and fine
-< HTTP/1.1 200 OK
-< Content-Type: application/json
-< x-ms-request-id: 72b68108-18a4-493c-ac75-d0abd82f0a11
-< x-ms-client-request-id: my-request-001
-< Content-Length: 159
-<
-* Connection #0 to host 127.0.0.1 left intact
-{"outputs":{"Sample_Output_Name":{"dims":["1","10"],"dataType":1,"rawData":"6OpzRFquGsSFdM1FyAEnRFtRZcRa9NDEUBj0xI4ydsJIS0LE//CzxA==","dataLocation":"DEFAULT"}}}%
-```
+If you prefer using an ONNX Runtime Server built with [rsyslog](https://www.rsyslog.com/) support ([build instructions](https://github.com/microsoft/onnxruntime/blob/master/BUILD.md#build-onnx-runtime-server-on-linux)), you should be able to see the logs in `/var/log/syslog` after the ONNX Runtime Server starts. For details on how to use rsyslog, please refer to [these guides](https://www.rsyslog.com/category/guides-for-rsyslog/).
-#### Server Side
+## Report Issues
-And here is what the output on the server side looks like with logging level of verbose:
+If you run into any issues or want to ask questions about the server, please feel free to file an issue in this repo and include the version and commit id reported on the command line.
-```
-2019-04-04 23:48:26.395200744 [V:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, predict_request_handler.cc:40 Predict] Name: mymodel Version: 3 Action: predict
-2019-04-04 23:48:26.395289437 [V:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, predict_request_handler.cc:46 Predict] x-ms-client-request-id: [my-request-001]
-2019-04-04 23:48:26.395540707 [I:onnxruntime:InferenceSession, inference_session.cc:736 Run] Running with tag: 72b68108-18a4-493c-ac75-d0abd82f0a11
-2019-04-04 23:48:26.395596858 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, inference_session.cc:976 CreateLoggerForRun] Created logger for run with id of 72b68108-18a4-493c-ac75-d0abd82f0a11
-2019-04-04 23:48:26.395731391 [I:onnxruntime:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:42 Execute] Begin execution
-2019-04-04 23:48:26.395763319 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:45 Execute] Size of execution plan vector: 12
-2019-04-04 23:48:26.396228981 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Convolution28
-2019-04-04 23:48:26.396580161 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus30
-2019-04-04 23:48:26.396623732 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 10
-2019-04-04 23:48:26.396878822 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: ReLU32
-2019-04-04 23:48:26.397091882 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Pooling66
-2019-04-04 23:48:26.397126243 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 11
-2019-04-04 23:48:26.397772701 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Convolution110
-2019-04-04 23:48:26.397818174 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 13
-2019-04-04 23:48:26.398060592 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus112
-2019-04-04 23:48:26.398095300 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 14
-2019-04-04 23:48:26.398257563 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: ReLU114
-2019-04-04 23:48:26.398426740 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Pooling160
-2019-04-04 23:48:26.398466031 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 15
-2019-04-04 23:48:26.398542823 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212_reshape0
-2019-04-04 23:48:26.398599687 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212_reshape1
-2019-04-04 23:48:26.398692631 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Times212
-2019-04-04 23:48:26.398731471 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 17
-2019-04-04 23:48:26.398832735 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:156 Execute] Releasing node ML values after computing kernel: Plus214
-2019-04-04 23:48:26.398873229 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:197 ReleaseNodeMLValues] Releasing mlvalue with index: 19
-2019-04-04 23:48:26.398922929 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:160 Execute] Fetching output.
-2019-04-04 23:48:26.398956560 [V:VLOG1:72b68108-18a4-493c-ac75-d0abd82f0a11, sequential_executor.cc:163 Execute] Done with execution.
-```
\ No newline at end of file
diff --git a/docs/PyOp.md b/docs/PyOp.md
index 1d0e5e27bcd42..e82d7b6995b25 100644
--- a/docs/PyOp.md
+++ b/docs/PyOp.md
@@ -1,10 +1,10 @@
# Python Operator
-To facilitate Python coders on model developing, onnxruntime provides a way to invoke operators implemented in Python.
+The Python Operator provides the capability to easily invoke any custom Python code within a single node of an ONNX graph using ONNX Runtime. This can be useful for quicker experimentation when a model requires operators that are not officially supported in ONNX and ONNX Runtime, particularly if there is already a Python implementation for the required functionality. This should be used with discretion in production scenarios, and all security or other risks should be considered.
-## Implemenation
-The feature is implemented under onnxruntime/core/language_interop_ops.
+## Design Overview
+The feature can be found under [onnxruntime/core/language_interop_ops](../onnxruntime/core/language_interop_ops).
All Python C API dependent code is compiled into a dynamically linked library named pywrapper.
-Before calling into Python script, pywrapper will convert onnxruntime tensor(s) to numpy(s), which get converted back when done.
+Before calling into the Python script, pywrapper converts onnxruntime tensor(s) to numpy array(s), which are converted back when the computation completes.
Here is a chart illustrating the calling sequence:
onnxruntime pywrapper script
@@ -13,18 +13,20 @@ onnxruntime pywrapper script
| call with tensor(s) | ------------------------------> |
| | call with numpy(s) |
| | | compute
- | | <----------------------------- |
+ | | <------------------------------ |
| <------------------------------ | return numpys(s) |
| return tensor(s) | |
-## Usage
-Step 1, build onnxruntime with“--config Release --enable_language_interop_ops --build_shared_lib” and override existing onnxruntime binary with the latest, then copy onnxruntime_pywrapper.dll or libonnxruntime_pywrapper.so or libonnxruntime_pywrapper.dylib to the path where onnxruntime binary is placed.
-Note:
-* It is suggested to compile within the Python environment where inferencing will happen. For example, if inferencing will happen in a conda env named myconda1, please compile the binary within that environment as well;
-* If "--numpy_version=..." is specified, Python operator will build with that version.
+## How to Use
+### Step 1
+Build onnxruntime with `--config Release --enable_language_interop_ops --build_shared_lib` and replace the existing onnxruntime binary with the newly built one. Then, copy onnxruntime_pywrapper.dll, libonnxruntime_pywrapper.so, or libonnxruntime_pywrapper.dylib to the path where the onnxruntime binary is located.
+**Notes:**
+* It is recommended to compile within the Python environment where inferencing will happen. For example, if inferencing will happen in a conda env named myconda1, please compile the binary within that environment as well.
+* If `--numpy_version=...` is specified, the Python operator will build with that version.
-Step 2, create an onnx model containing Python operator nodes:
+### Step 2
+Create an onnx model containing Python operator nodes:
```python
ad1_node = helper.make_node('Add', ['A','B'], ['S'])
mul_node = helper.make_node('Mul', ['C','D'], ['P'])
@@ -48,7 +50,8 @@ graph = helper.make_graph([ad1_node,mul_node,py1_node,ad2_node,py2_node,sub_node
model = helper.make_model(graph, producer_name = 'pyop_model')
onnx.save(model, './model.onnx')
```
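+
+For illustration, a Python operator node itself can be declared roughly as follows. This is only a sketch: `module` and `class_name` refer to the mymodule.py implemented in Step 3, while the domain, input/output names, and element types here are placeholders (the attribute names mirror those used in the conversion example at the end of this document):
+
+```python
+py_node = helper.make_node(
+    'PyOp',                            # op_type of a Python operator node
+    ['S'], ['T'],                      # placeholder input/output names
+    domain='MyDomain',                 # placeholder custom domain
+    module='mymodule',                 # Python module implemented in Step 3
+    class_name='Multi_1',              # class inside mymodule.py
+    input_types=[TensorProto.FLOAT],   # element types of the inputs
+    output_types=[TensorProto.FLOAT])  # element types of the outputs
+```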
-Step 3, implement mymodule.py:
+### Step 3
+Implement mymodule.py:
```python
class Multi_1:
def __init__(self, W1, W2, W3):
@@ -63,23 +66,24 @@ class Multi_2:
r1, r2 = H + N, N + E
return r1, r2
```
-Step 4, copy mymodule.py into Python sys.path, then reference with onnxruntime. On Windows, please set PYTHONHOME beforehand. It should point to directory where the python is installed, such as C:\Python37 or C:\ProgramData\Anaconda3\envs\myconda1 if it is in conda.
+### Step 4
+Copy mymodule.py into the Python sys.path, then run inference with onnxruntime. On Windows, please set PYTHONHOME beforehand. It should point to the directory where Python is installed, such as C:\Python37, or C:\ProgramData\Anaconda3\envs\myconda1 if it is in conda.
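+
+A minimal sketch of this step, assuming the model built in Step 2 with graph inputs A, B, C, and D and 2x2 float tensors (the actual input names and shapes depend on your graph):
+
+```python
+import numpy as np
+import onnxruntime as ort  # the build with --enable_language_interop_ops from Step 1
+
+sess = ort.InferenceSession('./model.onnx')
+feeds = {name: np.ones((2, 2), dtype=np.float32) for name in ['A', 'B', 'C', 'D']}
+outputs = sess.run(None, feeds)  # runs the graph, including the Python operator nodes
+print(outputs)
+```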
## Supported Data Types
-* TensorProto.BOOL,
-* TensorProto.UINT8,
-* TensorProto.UINT16,
-* TensorProto.UINT32,
-* TensorProto.INT16,
-* TensorProto.INT32,
-* TensorProto.FLOAT,
+* TensorProto.BOOL
+* TensorProto.UINT8
+* TensorProto.UINT16
+* TensorProto.UINT32
+* TensorProto.INT16
+* TensorProto.INT32
+* TensorProto.FLOAT
* TensorProto.DOUBLE
## Limitations
-* On Windows, "--config Debug" has known issues, build with "--config RelWithDebInfo" if need debugging symbols;
-* Due to python C API restrictions, multi-threading is disabled, meaning Python operators will run sequentially.
+* On Windows, `--config Debug` has known issues. Please build with `--config RelWithDebInfo` if debugging symbols are needed.
+* Due to Python C API restrictions, multi-threading is disabled so Python operators will run sequentially.
-## Test
+## Test Coverage
The operator has been tested on multiple platforms, with or without conda:
Platform | Python 3.5 | Python 3.6 | Python 3.7
@@ -88,3 +92,47 @@ Windows | (conda) passed | (conda) passed | passed
Linux | (conda) passed | (conda) passed | passed
Mac | (conda) passed | (conda) passed | (conda) passed
+## Example
+Developers can use PyOp during model conversion to work around missing operators:
+```python
+import os
+import numpy as np
+from onnx import *
+from skl2onnx import convert_sklearn
+from skl2onnx.common.data_types import FloatTensorType
+from skl2onnx.common.utils import check_input_and_output_numbers
+from sklearn.decomposition import NMF
+
+X = np.array([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]],dtype=np.single)
+nmf = NMF(n_components=2, init='random', random_state=0)
+W = np.array(nmf.fit_transform(X), dtype=np.single)
+
+def calculate_sklearn_nmf_output_shapes(operator):
+ check_input_and_output_numbers(operator, output_count_range=1, input_count_range=1)
+ operator.outputs[0].type.shape = operator.inputs[0].type.shape
+
+def convert_nmf(scope, operator, container):
+ ws = [str(w) for w in W.flatten()]
+ attrs = {'W':'|'.join(ws)}
+ container.add_node(op_type='PyOp', name='nmf', inputs=['X'], outputs=['variable'],
+ op_version=10, op_domain='MyDomain', module='mymodule', class_name='MyNmf',
+ input_types=[TensorProto.FLOAT], output_types=[TensorProto.FLOAT], **attrs)
+
+custom_shape_calculators = {type(nmf): calculate_sklearn_nmf_output_shapes}
+custom_conversion_functions = {type(nmf): convert_nmf}
+initial_types = [('X', FloatTensorType([6,2]))]
+onx = convert_sklearn(nmf, '', initial_types, '', None, custom_conversion_functions, custom_shape_calculators)
+with open("model.onnx", "wb") as f:
+ f.write(onx.SerializeToString())
+```
+mymodule.py:
+```python
+import numpy as np
+class MyNmf:
+ def __init__(self,W):
+ A = []
+ for w in W.split('|'):
+ A.append(float(w))
+ self.__W = np.array(A,dtype=np.single).reshape(6,2)
+ def compute(self,X):
+ return self.__W
+```
diff --git a/docs/Versioning.md b/docs/Versioning.md
index cf503df1d820a..d646d777d8335 100644
--- a/docs/Versioning.md
+++ b/docs/Versioning.md
@@ -7,12 +7,13 @@ same as what is described in the semantic versioning doc linked above.
## Current stable release version
The version number of the current stable release can be found
-[here](../VERSION_NUMBER)
+[here](../VERSION_NUMBER).
## Release cadence
See [Release Management](ReleaseManagement.md)
-## Compatibility with ONNX opsets
+# Compatibility
+## ONNX Compatibility
ONNX Runtime supports both backwards and forward compatibility.
### Backwards compatibility
@@ -26,14 +27,31 @@ the model doesn't use ops that were newly introduced in opset ver 9.
### Version matrix
The following table summarizes the relationship between the ONNX Runtime version and the ONNX
-opset version implemented in that release.
-
-| ONNX Runtime release version | ONNX opset version<br>implemented in this release | ONNX ML opset version<br>implemented in this release | Supported ONNX IR version |
-|------------------------------|--------------------|----------------------|------------------|
-| 0.4.0 | 10 | 1 | 5 |
-| 0.3.1 | 9 | 1 | 3 |
-| 0.3.0 | 9 | 1 | 3 |
-| 0.2.1 | 8 | 1 | 3 |
-| 0.2.0 | 8 | 1 | 3 |
-| 0.1.5 | 8 | 1 | 3 |
-| 0.1.4 | 8 | 1 | 3 |
+opset version implemented in that release. Please note the backwards and forward compatibility notes above.
+For more details on ONNX Release versions, see [this page](https://github.com/onnx/onnx/blob/master/docs/Versioning.md).
+
+| ONNX Runtime release version | ONNX release version | ONNX opset version | ONNX ML opset version | Supported ONNX IR version | [WinML compatibility](https://docs.microsoft.com/en-us/windows/ai/windows-ml/)|
+|------------------------------|--------------------|--------------------|----------------------|------------------|------------------|
+| 0.5.0 | 1.5 | 10 | 1 | 5 | -- |
+| 0.4.0 | 1.5 | 10 | 1 | 5 | -- |
+| 0.3.1<br>0.3.0 | 1.4 | 9 | 1 | 3 | -- |
+| 0.2.1<br>0.2.0 | 1.3 | 8 | 1 | 3 | 1903 (19H1)+ |
+| 0.1.5<br>0.1.4 | 1.3 | 8 | 1 | 3 | 1809 (RS5)+ |
+
+
+## Tool Compatibility
+A variety of tools can be used to create ONNX models. Unless otherwise noted, please use the latest released version of the tools to convert/export the ONNX model. Many tools are backwards compatible and support multiple ONNX versions. Join this with the table above to evaluate ONNX Runtime compatibility.
+
+
+|Tool|Recommended Version|Supported ONNX version(s)|
+|---|---|---|
+|[PyTorch](https://pytorch.org/)|[Latest stable](https://pytorch.org/get-started/locally/)|1.2-1.5*<br>*may require [ONNX version converter](https://github.com/onnx/onnx/blob/master/docs/VersionConverter.md) to convert to desired opset #*|
+|[ONNXMLTools](https://pypi.org/project/onnxmltools/)<br>CoreML, LightGBM, XGBoost, LibSVM|[Latest stable](https://github.com/onnx/onnxmltools/releases)|1.2-1.5|
+|[ONNXMLTools](https://pypi.org/project/onnxmltools/)<br>SparkML|[Latest stable](https://github.com/onnx/onnxmltools/releases)|1.4-1.5|
+|[SKLearn-ONNX](https://pypi.org/project/skl2onnx/)|[Latest stable](https://github.com/onnx/sklearn-onnx/releases)|1.2-1.5|
+|[Keras-ONNX](https://pypi.org/project/keras2onnx/)|[Latest stable](https://github.com/onnx/keras-onnx/releases)|1.2-1.5|
+|[Tensorflow-ONNX](https://pypi.org/project/tf2onnx/)|[Latest stable](https://github.com/onnx/tensorflow-onnx/releases)|1.2-1.5|
+|[WinMLTools](https://docs.microsoft.com/en-us/windows/ai/windows-ml/convert-model-winmltools)|[Latest stable](https://pypi.org/project/winmltools/)|1.2-1.4|
+|[AutoML](https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-automated-ml)|[1.0.39+](https://pypi.org/project/azureml-automl-core)|1.5|
+| |[1.0.33](https://pypi.org/project/azureml-automl-core/1.0.33/)|1.4|
+
diff --git a/docs/execution_providers/OpenVINO-ExecutionProvider.md b/docs/execution_providers/OpenVINO-ExecutionProvider.md
index 5cfa516af3dca..1d5838268d3f6 100644
--- a/docs/execution_providers/OpenVINO-ExecutionProvider.md
+++ b/docs/execution_providers/OpenVINO-ExecutionProvider.md
@@ -6,9 +6,9 @@ OpenVINO Execution Provider enables deep learning inference on Intel CPUs, Intel
Below table shows the ONNX layers supported using OpenVINO Execution Provider and the mapping between ONNX layers and OpenVINO layers. The below table also lists the Intel hardware support for each of the layers. CPU refers to Intel®
Atom, Core, and Xeon processors. GPU refers to the Intel Integrated Graphics. VPU refers to USB based Intel® MovidiusTM
-VPUs as well as Intel® Vision accelerator Design with Intel Movidius TM MyriadX VPU.
+VPUs as well as Intel® Vision accelerator Design with Intel Movidius TM MyriadX VPU.
-| **ONNX Layers** | **OpenVINO Layers** | **CPU** | **GPU** | **VPU** |
+| **ONNX Layers** | **OpenVINO Layers** | **CPU** | **GPU** | **VPU** |
| --- | --- | --- | --- | --- |
| Add | Eltwise (operation=sum) | Yes | Yes | Yes
| AveragePool | Pooling(pool\_method=avg) | Yes | Yes | Yes
@@ -33,7 +33,7 @@ VPUs as well as Intel® Vision accelerator Design with Intel Movidiu
| UnSqueeze | Reshape | Yes | Yes | Yes
| LeakyRelu | ReLU | Yes | Yes | Yes
-*MatMul is supported in GPU only when the following layer is an Add layer in the topology.
+*MatMul is supported in GPU only when the following layer is an Add layer in the topology.
# Topology Support
@@ -41,17 +41,17 @@ Below topologies are supported from ONNX open model zoo using OpenVINO Execution
## Image Classification Networks
-| **Topology** | **CPU** | **GPU** | **VPU** |
-| --- | --- | --- | --- |
+| **Topology** | **CPU** | **GPU** | **VPU** |
+| --- | --- | --- | --- |
| bvlc\_alexnet | Yes | Yes | Yes
| bvlc\_googlenet | Yes | Yes | Yes
-| bvlc\_reference\_caffenet | Yes | Yes | Yes
-| bvlc\_reference\_rcnn\_ilsvrc13 | Yes | Yes | Yes
+| bvlc\_reference\_caffenet | Yes | Yes | Yes
+| bvlc\_reference\_rcnn\_ilsvrc13 | Yes | Yes | Yes
| densenet121 | Yes | Yes | Yes
-| Inception\_v1 | Yes | Yes | No
+| Inception\_v1 | Yes | Yes | Yes**
| Inception\_v2 | Yes | Yes | Yes
| Shufflenet | Yes | Yes | Yes
-| Zfnet512 | Yes | Yes | Yes
+| Zfnet512 | Yes | Yes | Yes
| Squeeznet 1.1 | Yes | Yes | Yes
| Resnet18v1 | Yes | Yes | Yes
| Resnet34v1 | Yes | Yes | Yes
@@ -62,29 +62,32 @@ Below topologies are supported from ONNX open model zoo using OpenVINO Execution
| Resnet34v2 | Yes | Yes | Yes
| Resnet50v2 | Yes | Yes | Yes
| Resnet101v2 | Yes | Yes | Yes
-| Resnet152v2 | Yes | Yes | Yes
+| Resnet152v2 | Yes | Yes | Yes
| Mobilenetv2 | Yes | Yes | Yes
| vgg16 | Yes | Yes | Yes
| vgg19 | Yes | Yes | Yes
+
## Image Recognition Networks
-| **Topology** | **CPU** | **GPU** | **VPU** |
-| --- | --- | --- | --- |
-| MNIST | Yes | Yes | No
+| **Topology** | **CPU** | **GPU** | **VPU** |
+| --- | --- | --- | --- |
+| MNIST | Yes | Yes | Yes**
+
+**Inception_v1 and MNIST are supported in OpenVINO R1.1 and are not supported in OpenVINO R5.0.1.
## Object Detection Networks
-| **Topology** | **CPU** | **GPU** | **VPU** |
-| --- | --- | --- | --- |
+| **Topology** | **CPU** | **GPU** | **VPU** |
+| --- | --- | --- | --- |
|TinyYOLOv2 | Yes | Yes | Yes
-| ResNet101\_DUC\_HDC | Yes | Yes | No
+| ResNet101\_DUC\_HDC | Yes | No | No
-# Application code changes for VAD-R performance scaling
+# Application code changes for VAD-M performance scaling
-VAD-R has 8 VPUs and is suitable for applications that require multiple inferences to run in parallel. We use batching approach for performance scaling on VAD-R.
+VAD-M has 8 VPUs and is suitable for applications that require multiple inferences to run in parallel. We use a batching approach for performance scaling on VAD-M.
-Below python code snippets provide sample classification code to batch input images, load a model and process the output results.
+The Python code snippets below provide sample classification code to batch input images, load a model, and process the output results.
~~~
import onnxruntime as rt
@@ -95,7 +98,7 @@ import sys
import cv2
import numpy
import time
-import glob
+import glob
~~~
### Load the input onnx model
@@ -111,19 +114,19 @@ for i in range(iters):
images = [cv2.imread(file) for file in glob.glob(str(sys.argv[2])+'/*.jpg')]
for img in images:
# resizing the image
- img = cv2.resize(img, (224,224))
- # convert image to numpy
- x = numpy.asarray(img).astype(numpy.float32)
- x = numpy.transpose(x, (2,0,1))
+ img = cv2.resize(img, (224,224))
+ # convert image to numpy
+ x = numpy.asarray(img).astype(numpy.float32)
+ x = numpy.transpose(x, (2,0,1))
# expand the dimension and batch the images
- x = numpy.expand_dims(x,axis=0)
- if y is None:
- y = x
- else:
- y = numpy.concatenate((y,x), axis=0)
+ x = numpy.expand_dims(x,axis=0)
+ if y is None:
+ y = x
+ else:
+ y = numpy.concatenate((y,x), axis=0)
~~~
-### Start Inference
+### Start Inference
~~~
res = sess.run([sess.get_outputs()[0].name], {sess.get_inputs()[0].name: y})
~~~
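+
+As a short post-processing sketch for a classification model, the batched scores returned above can be reduced to the top class per image; the exact output layout depends on the model you load:
+~~~
+scores = numpy.array(res[0])            # first output: one score vector per batched image
+top_classes = numpy.argmax(scores, axis=1)
+print(top_classes)
+~~~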
diff --git a/docs/python/README.rst b/docs/python/README.rst
index 756383579ee45..0fe76b1624ef3 100644
--- a/docs/python/README.rst
+++ b/docs/python/README.rst
@@ -52,6 +52,11 @@ replaces *scikit-learn* to compute the predictions.
Changes
-------
+0.5.0
+^^^^^
+
+Release Notes : https://github.com/Microsoft/onnxruntime/releases/tag/v0.5.0
+
0.4.0
^^^^^
diff --git a/docs/python/examples/plot_pipeline.py b/docs/python/examples/plot_pipeline.py
index 5063479492429..0a002f6223e1b 100644
--- a/docs/python/examples/plot_pipeline.py
+++ b/docs/python/examples/plot_pipeline.py
@@ -21,7 +21,7 @@
"""
from onnxruntime.datasets import get_example
-example1 = get_example("mul_1.pb")
+example1 = get_example("mul_1.onnx")
import onnx
model = onnx.load(example1) # model is a ModelProto protobuf message
diff --git a/docs/python/examples/plot_profiling.py b/docs/python/examples/plot_profiling.py
index 3844962033f9d..d5617d41726c5 100644
--- a/docs/python/examples/plot_profiling.py
+++ b/docs/python/examples/plot_profiling.py
@@ -19,7 +19,7 @@
#########################
# Let's load a very simple model and compute some prediction.
-example1 = get_example("mul_1.pb")
+example1 = get_example("mul_1.onnx")
sess = rt.InferenceSession(example1)
input_name = sess.get_inputs()[0].name
diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index 462aed63f1d68..8a37553ea976b 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -15,21 +15,80 @@
#include "core/framework/fence.h"
#include "core/session/onnxruntime_c_api.h"
+// Struct to represent a physical device.
+struct OrtDevice {
+ using DeviceType = int8_t;
+ using MemoryType = int8_t;
+ using DeviceId = int16_t;
+
+ // Pre-defined device types.
+ static const DeviceType CPU = 0;
+ static const DeviceType GPU = 1; //CUDA
+ static const DeviceType FPGA = 2;
+
+ struct MemType {
+ // Pre-defined memory types.
+ static const MemoryType DEFAULT = 0;
+ static const MemoryType CUDA_PINNED = 1;
+ };
+
+ constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
+ : device_type(device_type_),
+ memory_type(memory_type_),
+ device_id(device_id_) {}
+
+ constexpr OrtDevice() : OrtDevice(CPU, MemType::DEFAULT, 0) {}
+
+ DeviceType Type() const {
+ return device_type;
+ }
+
+ MemoryType MemType() const {
+ return memory_type;
+ }
+
+ DeviceId Id() const {
+ return device_id;
+ }
+
+ std::string ToString() const {
+ std::ostringstream ostr;
+ ostr << "Device: ["
+ << " type:" << static_cast(device_type)
+ << " memory_type:" << static_cast(memory_type)
+ << " device_id:" << device_id
+ << "]";
+ return ostr.str();
+ }
+
+ private:
+ // Device type.
+ DeviceType device_type;
+
+ // Memory type.
+ MemoryType memory_type;
+
+ // Device index.
+ DeviceId device_id;
+};
+
struct OrtAllocatorInfo {
// use string for name, so we could have customized allocator in execution provider.
const char* name;
int id;
OrtMemType mem_type;
OrtAllocatorType type;
+ OrtDevice device;
- constexpr OrtAllocatorInfo(const char* name_, OrtAllocatorType type_, int id_ = 0, OrtMemType mem_type_ = OrtMemTypeDefault)
+ constexpr OrtAllocatorInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, OrtMemType mem_type_ = OrtMemTypeDefault)
#if (defined(__GNUC__) || defined(__clang__))
__attribute__((nonnull))
#endif
: name(name_),
id(id_),
mem_type(mem_type_),
- type(type_) {
+ type(type_),
+ device(device_) {
}
// To make OrtAllocatorInfo become a valid key in std map
@@ -67,6 +126,8 @@ std::ostream& operator<<(std::ostream& out, const OrtAllocatorInfo& info);
namespace onnxruntime {
constexpr const char* CPU = "Cpu";
+constexpr const char* CUDA = "Cuda";
+constexpr const char* CUDA_PINNED = "CudaPinned";
// forward declaration
class SessionState;
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index d3b1aa5d75020..6e7c919601060 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -84,20 +84,6 @@ class IExecutionProvider {
*/
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const;
- /**
- * Copy tensor between execution providers. It's always a deep copy
- * Either src.location is CPU, or dst.location is CPU. They can't be both on CPU.
- */
- virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const = 0;
-
- /**
- * Copy tensor between execution providers on specified exec queue
- * It's always a deep copy
- * Either src.location is CPU, or dst.location is CPU. They can't be both on CPU.
- */
- virtual common::Status CopyTensor(const Tensor& src, Tensor& dst,
- int exec_queue_id) const;
-
/**
Returns an opaque handle whose exact type varies based on the provider
and is interpreted accordingly by the corresponding kernel implementation.
diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h
index b7b41ebcf70c9..3a0d35e298f98 100644
--- a/include/onnxruntime/core/framework/kernel_registry.h
+++ b/include/onnxruntime/core/framework/kernel_registry.h
@@ -24,9 +24,12 @@ class KernelRegistry {
// for its clients unless the factory is managing the lifecycle of the pointer
// itself.
// TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent
- Status TryCreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider,
- const std::unordered_map<int, OrtValue>& initialized_tensors,
- const OrtValueNameIdxMap& mlvalue_name_idx_map, const FuncManager& funcs_mgr,
+ Status TryCreateKernel(const onnxruntime::Node& node,
+ const IExecutionProvider& execution_provider,
+ const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
+ const OrtValueNameIdxMap& mlvalue_name_idx_map,
+ const FuncManager& funcs_mgr,
+ const DataTransferManager& data_transfer_mgr,
std::unique_ptr<OpKernel>& op_kernel) const;
// Check if an execution provider can create kernel for a node and return
diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h
index f38e6858847ee..e377f0d4e4239 100644
--- a/include/onnxruntime/core/framework/op_kernel_info.h
+++ b/include/onnxruntime/core/framework/op_kernel_info.h
@@ -15,16 +15,20 @@ namespace onnxruntime {
class OrtValueNameIdxMap;
class FuncManager;
+class DataTransferManager;
// A very light-weight class, which works as an aggregated
// view of all data needed for constructing a Kernel instance.
// NOTE: it does not own/hold any objects.
class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
public:
- explicit OpKernelInfo(const onnxruntime::Node& node, const KernelDef& kernel_def,
+ explicit OpKernelInfo(const onnxruntime::Node& node,
+ const KernelDef& kernel_def,
const IExecutionProvider& execution_provider,
- const std::unordered_map<int, OrtValue>& initialized_tensors,
- const OrtValueNameIdxMap& mlvalue_name_idx_map, const FuncManager& funcs_mgr);
+ const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
+ const OrtValueNameIdxMap& mlvalue_name_idx_map,
+ const FuncManager& funcs_mgr,
+ const DataTransferManager& data_transfer_mgr);
OpKernelInfo(const OpKernelInfo& other);
@@ -36,6 +40,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const IExecutionProvider* GetExecutionProvider() const noexcept;
+ const DataTransferManager& GetDataTransferManager() const noexcept;
+
const onnxruntime::Node& node() const noexcept;
bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
@@ -53,9 +59,10 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
// For non cpu/cuda case, this pointer should be set so that function kernel
// will delegate kernel compute call to compute call.
gsl::not_null<const IExecutionProvider*> execution_provider_;
- const std::unordered_map<int, OrtValue>& initialized_tensors_;
+ const std::unordered_map<int, OrtValue>& constant_initialized_tensors_;
const OrtValueNameIdxMap& ort_value_name_idx_map_;
const FuncManager& funcs_mgr_;
+ const DataTransferManager& data_transfer_mgr_;
ProtoHelperNodeContext proto_helper_context_;
};
diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h
index 52285311e5254..b66607853856a 100644
--- a/include/onnxruntime/core/framework/run_options.h
+++ b/include/onnxruntime/core/framework/run_options.h
@@ -14,8 +14,8 @@ struct OrtRunOptions {
/// Log severity. See https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/common/logging/severity.h
/// Default = -1 (use the log severity from the InferenceSession that the Run is for).
int run_log_severity_level = -1;
- unsigned run_log_verbosity_level = 0; ///< VLOG level if debug build and run_log_severity_level is 0 (VERBOSE).
- std::string run_tag; ///< A tag for the Run() calls using this.
+ int run_log_verbosity_level = 0; ///< VLOG level if debug build and run_log_severity_level is 0 (VERBOSE).
+ std::string run_tag; ///< A tag for the Run() calls using this.
// Set to 'true' to ensure the termination of all the outstanding Run() calls
// that use this OrtRunOptions instance. Some of the outstanding Run() calls may
diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h
index 260d1731bc6c0..35eb359c714a3 100644
--- a/include/onnxruntime/core/framework/tensor.h
+++ b/include/onnxruntime/core/framework/tensor.h
@@ -170,7 +170,7 @@ class Tensor final {
/**
The number of bytes of data.
*/
- size_t Size() const {
+ size_t SizeInBytes() const {
size_t ret;
int64_t l = shape_.Size();
if (l >= static_cast(std::numeric_limits::max())) {
diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h
index 5cf9cf08e0868..acf39638fe0db 100644
--- a/include/onnxruntime/core/framework/tensor_shape.h
+++ b/include/onnxruntime/core/framework/tensor_shape.h
@@ -37,6 +37,7 @@ class TensorShape : private std::vector<int64_t> {
TensorShape(const int64_t* dimension_sizes, size_t dimension_count);
TensorShape(const std::vector<int64_t>& dims);
+ TensorShape(std::vector<int64_t>&& dims);
TensorShape(const std::initializer_list<int64_t>& dims);
diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index 639ff301ff08f..5872228f383d2 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -18,6 +18,7 @@ constexpr const char* kOnnxDomain = "";
constexpr const char* kOnnxDomainAlias = "ai.onnx";
constexpr const char* kMLDomain = "ai.onnx.ml";
constexpr const char* kMSDomain = "com.microsoft";
+constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc";
constexpr const char* kNGraphDomain = "com.intel.ai";
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
@@ -27,5 +28,6 @@ constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider";
constexpr const char* kNupharExecutionProvider = "NupharExecutionProvider";
constexpr const char* kBrainSliceExecutionProvider = "BrainSliceExecutionProvider";
constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider";
+constexpr const char* kNnapiExecutionProvider = "NnapiExecutionProvider";
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 66b5954cf5177..b626a7541713f 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -279,6 +279,10 @@ class Node {
return !attr_to_subgraph_map_.empty();
}
+ /** Get the const subgraphs from a node.
+ @remarks Creates a new vector so calling ContainsSubgraphs first is preferred. */
+ std::vector<gsl::not_null<const Graph*>> GetSubgraphs() const;
+
/** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
@returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
nullptr if the Node has no subgraphs.
@@ -500,6 +504,9 @@ class Graph {
/** Removes all initializer tensors from this Graph and releases the memory they were using. */
void CleanAllInitializedTensors() noexcept;
+ /** Returns true if an initializer value can be overridden by a graph input with the same name. */
+ bool CanOverrideInitializer() const noexcept { return ir_version_ >= 4; }
+
/** Gets the Graph inputs excluding initializers.
These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs.
@remarks Contains no nullptr values. */
@@ -750,6 +757,12 @@ class Graph {
/** Returns true if this is a subgraph or false if it is a high-level graph. */
bool IsSubgraph() const { return parent_graph_ != nullptr; }
+ /** Returns the parent graph if this is a subgraph */
+ const Graph* ParentGraph() const { return parent_graph_; }
+
+ /** Returns the mutable parent graph if this is a subgraph */
+ Graph* MutableParentGraph() { return parent_graph_; }
+
/** Construct a Graph instance for a subgraph that is created from a GraphProto attribute in a Node.
Inherits some properties from the parent graph.
@param parent_graph The Graph containing the Node which has a GraphProto attribute.
@@ -840,7 +853,7 @@ class Graph {
// Build and verify node connection (edges).
// Verify NodeArg name/type/shape matching correctly.
- common::Status BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed);
+ common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
common::Status VerifyNoDuplicateName();
@@ -962,7 +975,7 @@ class Graph {
std::unordered_map model_functions_;
// Model IR version.
- Version ir_version_{};
+ Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
int name_generator_ = 0;
@@ -974,6 +987,9 @@ class Graph {
// NodeArgs that come from outer scope. Used when building a graph so that
// these don't get recorded as graph inputs in the GraphProto.
std::unordered_set<std::string> outer_scope_node_arg_names_;
+
+ // number of times Resolve has run.
+ int num_resolves_ = 0;
};
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_level.h b/include/onnxruntime/core/optimizer/graph_transformer_level.h
index ad7d71096ef69..4f2d5b305ce1d 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_level.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_level.h
@@ -7,11 +7,12 @@
namespace onnxruntime {
-enum class TransformerLevel : uint32_t {
+enum class TransformerLevel : int {
Default = 0,
Level1,
Level2,
- // Convenience enum to always get the max available value.
+ Level3,
+ // Convenience enum to always get the max available value.
// This way when we add more levels code which iterates over this enum does not need to change.
MaxTransformerLevel
};
diff --git a/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h b/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h
index 360de99b5cf62..66f258922c1f4 100644
--- a/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h
+++ b/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "core/session/onnxruntime_c_api.h"
+#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h
index 3fc4b7b51f4f3..81b5477b3cb4d 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_provider_factory.h
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "core/session/onnxruntime_c_api.h"
+#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h b/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h
index 03ef1158eeef5..a54b522d9e79f 100644
--- a/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h
+++ b/include/onnxruntime/core/providers/mkldnn/mkldnn_provider_factory.h
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "core/session/onnxruntime_c_api.h"
+#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/include/onnxruntime/core/providers/ngraph/ngraph_provider_factory.h b/include/onnxruntime/core/providers/ngraph/ngraph_provider_factory.h
index 0970362a2b557..87d98cdbdd34a 100644
--- a/include/onnxruntime/core/providers/ngraph/ngraph_provider_factory.h
+++ b/include/onnxruntime/core/providers/ngraph/ngraph_provider_factory.h
@@ -1,7 +1,7 @@
// Copyright(C) 2019 Intel Corporation
// Licensed under the MIT License
-#include "core/session/onnxruntime_c_api.h"
+#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
new file mode 100644
index 0000000000000..d8b6a1ec27634
--- /dev/null
+++ b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
@@ -0,0 +1,15 @@
+// Copyright 2019 JD.com Inc. JD AI
+
+#include "onnxruntime_c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options);
+
+#ifdef __cplusplus
+}
+#endif
+
+
diff --git a/include/onnxruntime/core/providers/openvino/openvino_provider_factory.h b/include/onnxruntime/core/providers/openvino/openvino_provider_factory.h
index eadcd45603762..08200319c71a2 100644
--- a/include/onnxruntime/core/providers/openvino/openvino_provider_factory.h
+++ b/include/onnxruntime/core/providers/openvino/openvino_provider_factory.h
@@ -1,7 +1,7 @@
// Copyright(C) 2019 Intel Corporation
// Licensed under the MIT License
-#include "core/session/onnxruntime_c_api.h"
+#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h
index bee1ae1b0939c..fb077fc5ff41d 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "core/session/onnxruntime_c_api.h"
+#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 03bf3a4467df3..6848fc31e453c 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -18,6 +18,7 @@ extern "C" {
#define _In_
#define _In_opt_
#define _Out_
+#define _Outptr_
#define _Out_opt_
#define _Inout_
#define _Inout_opt_
@@ -58,7 +59,6 @@ extern "C" {
#ifdef __cplusplus
// Windows users should use unicode paths when possible to bypass the MAX_PATH limitation
-// Every type name starting with 'P' is a pointer type, an opaque handler
// Every pointer marked with _In_ or _Out_, cannot be NULL. Caller should ensure that.
// for ReleaseXXX(...) functions, they can accept NULL pointer.
#define NO_EXCEPTION noexcept
@@ -152,6 +152,7 @@ ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo);
ORT_RUNTIME_CLASS(SessionOptions);
ORT_RUNTIME_CLASS(Callback);
ORT_RUNTIME_CLASS(CustomOpDomain);
+ORT_RUNTIME_CLASS(Allocator);
// When passing in an allocator to any ORT function, be sure that the allocator object
// is not destroyed until the last allocated object using it is freed.
@@ -169,76 +170,76 @@ typedef void(ORT_API_CALL* OrtLoggingFunction)(
/**
* \param out Should be freed by `OrtReleaseEnv` after use
*/
-ORT_API_STATUS(OrtCreateEnv, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Out_ OrtEnv** out)
+ORT_API_STATUS(OrtCreateEnv, OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out)
ORT_ALL_ARGS_NONNULL;
/**
* \param out Should be freed by `OrtReleaseEnv` after use
*/
ORT_API_STATUS(OrtCreateEnvWithCustomLogger, OrtLoggingFunction logging_function,
- _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
+ _In_opt_ void* logger_param, OrtLoggingLevel default_logging_level,
_In_ const char* logid,
- _Out_ OrtEnv** out);
+ _Outptr_ OrtEnv** out);
// TODO: document the path separator convention? '/' vs '\'
// TODO: should specify the access characteristics of model_path. Is this read only during the
// execution of OrtCreateSession, or does the OrtSession retain a handle to the file/directory
// and continue to access throughout the OrtSession lifetime?
// What sort of access is needed to model_path : read or read/write?
-ORT_API_STATUS(OrtCreateSession, _In_ OrtEnv* env, _In_ const ORTCHAR_T* model_path,
- _In_ const OrtSessionOptions* options, _Out_ OrtSession** out);
+ORT_API_STATUS(OrtCreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
+ _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out);
-ORT_API_STATUS(OrtCreateSessionFromArray, _In_ OrtEnv* env, _In_ const void* model_data, size_t model_data_length,
- _In_ const OrtSessionOptions* options, _Out_ OrtSession** out);
+ORT_API_STATUS(OrtCreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length,
+ _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out);
ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
- _In_ const OrtRunOptions* run_options,
+ _In_opt_ const OrtRunOptions* run_options,
_In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len,
- _In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);
+ _In_ const char* const* output_names, size_t output_names_len, _Outptr_ OrtValue** out);
/**
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
*/
-ORT_API_STATUS(OrtCreateSessionOptions, _Out_ OrtSessionOptions** output);
+ORT_API_STATUS(OrtCreateSessionOptions, _Outptr_ OrtSessionOptions** options);
// create a copy of an existing OrtSessionOptions
-ORT_API_STATUS(OrtCloneSessionOptions, _In_ OrtSessionOptions* in, _Out_ OrtSessionOptions** output);
-ORT_API_STATUS(OrtEnableSequentialExecution, _In_ OrtSessionOptions* options);
-ORT_API_STATUS(OrtDisableSequentialExecution, _In_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtCloneSessionOptions, _In_ const OrtSessionOptions* in_options, _Outptr_ OrtSessionOptions** out_options);
+ORT_API_STATUS(OrtEnableSequentialExecution, _Inout_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtDisableSequentialExecution, _Inout_ OrtSessionOptions* options);
// Enable profiling for this session.
-ORT_API_STATUS(OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
-ORT_API_STATUS(OrtDisableProfiling, _In_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtEnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
+ORT_API_STATUS(OrtDisableProfiling, _Inout_ OrtSessionOptions* options);
// Enable the memory pattern optimization.
// The idea is if the input shapes are the same, we could trace the internal memory allocation
// and generate a memory pattern for future request. So next time we could just do one allocation
// with a big chunk for all the internal memory allocation.
// Note: memory pattern optimization is only available when SequentialExecution enabled.
-ORT_API_STATUS(OrtEnableMemPattern, _In_ OrtSessionOptions* options);
-ORT_API_STATUS(OrtDisableMemPattern, _In_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtEnableMemPattern, _Inout_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtDisableMemPattern, _Inout_ OrtSessionOptions* options);
// Enable the memory arena on CPU
// Arena may pre-allocate memory for future usage.
// set this option to false if you don't want it.
-ORT_API_STATUS(OrtEnableCpuMemArena, _In_ OrtSessionOptions* options);
-ORT_API_STATUS(OrtDisableCpuMemArena, _In_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtEnableCpuMemArena, _Inout_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtDisableCpuMemArena, _Inout_ OrtSessionOptions* options);
// < logger id to use for session output
-ORT_API_STATUS(OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid);
+ORT_API_STATUS(OrtSetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid);
// < applies to session load, initialization, etc
-ORT_API_STATUS(OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level);
+ORT_API_STATUS(OrtSetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level);
// Set Graph optimization level.
// Available options are : 0, 1, 2.
// 0 -> Disable all optimizations
// 1 -> Enable basic optimizations
// 2 -> Enable all optimizations
-ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level);
+ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, int graph_optimization_level);
// How many threads in the session thread pool.
-ORT_API_STATUS(OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size);
+ORT_API_STATUS(OrtSetSessionThreadPoolSize, _Inout_ OrtSessionOptions* options, int session_thread_pool_size);
/**
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
@@ -257,35 +258,36 @@ ORT_API_STATUS(OrtSessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size
/**
* \param out should be freed by OrtReleaseTypeInfo after use
*/
-ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
+ORT_API_STATUS(OrtSessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info);
/**
* \param out should be freed by OrtReleaseTypeInfo after use
*/
-ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Out_ OrtTypeInfo** out);
+ORT_API_STATUS(OrtSessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info);
/**
* \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible in freeing it.
*/
ORT_API_STATUS(OrtSessionGetInputName, _In_ const OrtSession* sess, size_t index,
- _Inout_ OrtAllocator* allocator, _Out_ char** value);
+ _Inout_ OrtAllocator* allocator, _Outptr_ char** value);
ORT_API_STATUS(OrtSessionGetOutputName, _In_ const OrtSession* sess, size_t index,
- _Inout_ OrtAllocator* allocator, _Out_ char** value);
+ _Inout_ OrtAllocator* allocator, _Outptr_ char** value);
/**
* \return A pointer to the newly created object. The pointer should be freed by OrtReleaseRunOptions after use
*/
-ORT_API_STATUS(OrtCreateRunOptions, _Out_ OrtRunOptions** out);
+ORT_API_STATUS(OrtCreateRunOptions, _Outptr_ OrtRunOptions** out);
-ORT_API_STATUS(OrtRunOptionsSetRunLogVerbosityLevel, _In_ OrtRunOptions*, unsigned int);
+ORT_API_STATUS(OrtRunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int value);
ORT_API_STATUS(OrtRunOptionsSetRunTag, _In_ OrtRunOptions*, _In_ const char* run_tag);
-ORT_API_STATUS(OrtRunOptionsGetRunLogVerbosityLevel, _In_ OrtRunOptions*, _Out_ unsigned int* out);
-ORT_API_STATUS(OrtRunOptionsGetRunTag, _In_ OrtRunOptions*, _Out_ const char** out);
+ORT_API_STATUS(OrtRunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, _Out_ int* out);
+ORT_API_STATUS(OrtRunOptionsGetRunTag, _In_ const OrtRunOptions*, _Out_ const char** out);
// Set a flag so that any running OrtRun* calls that are using this instance of OrtRunOptions
// will exit as soon as possible if the flag is true.
-ORT_API_STATUS(OrtRunOptionsSetTerminate, _In_ OrtRunOptions*, _In_ int flag);
+ORT_API_STATUS(OrtRunOptionsEnableTerminate, _Inout_ OrtRunOptions* options);
+ORT_API_STATUS(OrtRunOptionsDisableTerminate, _Inout_ OrtRunOptions* options);
/**
* Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value
@@ -294,7 +296,7 @@ ORT_API_STATUS(OrtRunOptionsSetTerminate, _In_ OrtRunOptions*, _In_ int flag);
*/
ORT_API_STATUS(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
_In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type,
- _Out_ OrtValue** out);
+ _Outptr_ OrtValue** out);
/**
* Create a tensor with user's buffer. You can fill the buffer either before calling this function or after.
@@ -303,11 +305,11 @@ ORT_API_STATUS(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
*/
ORT_API_STATUS(OrtCreateTensorWithDataAsOrtValue, _In_ const OrtAllocatorInfo* info,
_Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len,
- ONNXTensorElementDataType type, _Out_ OrtValue** out);
+ ONNXTensorElementDataType type, _Outptr_ OrtValue** out);
// This function doesn't work with string tensor
// this is a no-copy method whose pointer is only valid until the backing OrtValue is free'd.
-ORT_API_STATUS(OrtGetTensorMutableData, _Inout_ OrtValue* value, _Out_ void** out);
+ORT_API_STATUS(OrtGetTensorMutableData, _Inout_ OrtValue* value, _Outptr_ void** out);
/**
* \Sets *out to 1 iff an OrtValue is a tensor, 0 otherwise
@@ -319,7 +321,7 @@ ORT_API_STATUS(OrtIsTensor, _In_ const OrtValue* value, _Out_ int* out);
* \param s each A string array. Each string in this array must be null terminated.
* \param s_len length of s
*/
-ORT_API_STATUS(OrtFillStringTensor, _In_ OrtValue* value, _In_ const char* const* s, size_t s_len);
+ORT_API_STATUS(OrtFillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len);
/**
* \param value A tensor created from OrtCreateTensor... function.
* \param len total data length, not including the trailing '\0' chars.
@@ -350,7 +352,7 @@ ORT_API_STATUS(OrtGetStringTensorContent, _In_ const OrtValue* value, _Out_ void
*/
ORT_API_STATUS(OrtTensorProtoToOrtValue, _In_ const void* input, int input_len,
_In_opt_ const ORTCHAR_T* input_file_path, _Inout_ void* preallocated, size_t preallocated_size,
- _Out_ OrtValue** out, _Out_ OrtCallback** deleter);
+ _Outptr_ OrtValue** out, _Outptr_ OrtCallback** deleter);
/**
* f will be freed in this call
@@ -366,19 +368,19 @@ ORT_API_STATUS(OrtGetTensorMemSizeInBytesFromTensorProto, _In_ const void* input
/**
* Don't free the 'out' value
*/
-ORT_API_STATUS(OrtCastTypeInfoToTensorInfo, _In_ OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out);
+ORT_API_STATUS(OrtCastTypeInfoToTensorInfo, _In_ const OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out);
/**
* Return OnnxType from OrtTypeInfo
*/
-ORT_API_STATUS(OrtOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ enum ONNXType* out);
+ORT_API_STATUS(OrtGetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ enum ONNXType* out);
/**
* The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo
*/
-ORT_API_STATUS(OrtCreateTensorTypeAndShapeInfo, OrtTensorTypeAndShapeInfo** out);
+ORT_API_STATUS(OrtCreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out);
-ORT_API_STATUS(OrtSetTensorElementType, _In_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type);
+ORT_API_STATUS(OrtSetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type);
/**
* \param info Created from OrtCreateTensorTypeAndShapeInfo() function
@@ -405,14 +407,14 @@ ORT_API_STATUS(OrtGetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeIn
/**
* \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use
*/
-ORT_API_STATUS(OrtGetTensorTypeAndShape, _In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
+ORT_API_STATUS(OrtGetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
/**
* Get the type information of an OrtValue
* \param value
* \param out The returned value should be freed by OrtReleaseTypeInfo after use
*/
-ORT_API_STATUS(OrtGetTypeInfo, _In_ const OrtValue* value, OrtTypeInfo** out);
+ORT_API_STATUS(OrtGetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out);
ORT_API_STATUS(OrtGetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out);
@@ -432,12 +434,12 @@ typedef enum OrtMemType {
OrtMemTypeDefault = 0, // the default allocator for execution provider
} OrtMemType;
-ORT_API_STATUS(OrtCreateAllocatorInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out);
+ORT_API_STATUS(OrtCreateAllocatorInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtAllocatorInfo** out);
/**
* Convenience function for special case of OrtCreateAllocatorInfo, for the CPU allocator. Uses name = "Cpu" and id = 0.
*/
-ORT_API_STATUS(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out)
+ORT_API_STATUS(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Outptr_ OrtAllocatorInfo** out)
ORT_ALL_ARGS_NONNULL;
/**
@@ -450,17 +452,16 @@ ORT_ALL_ARGS_NONNULL;
/**
* Do not free the returned value
*/
-ORT_API_STATUS(OrtAllocatorInfoGetName, _In_ OrtAllocatorInfo* ptr, _Out_ const char** out);
-ORT_API_STATUS(OrtAllocatorInfoGetId, _In_ OrtAllocatorInfo* ptr, _Out_ int* out);
-ORT_API_STATUS(OrtAllocatorInfoGetMemType, _In_ OrtAllocatorInfo* ptr, _Out_ OrtMemType* out);
-ORT_API_STATUS(OrtAllocatorInfoGetType, _In_ OrtAllocatorInfo* ptr, _Out_ OrtAllocatorType* out);
+ORT_API_STATUS(OrtAllocatorInfoGetName, _In_ const OrtAllocatorInfo* ptr, _Out_ const char** out);
+ORT_API_STATUS(OrtAllocatorInfoGetId, _In_ const OrtAllocatorInfo* ptr, _Out_ int* out);
+ORT_API_STATUS(OrtAllocatorInfoGetMemType, _In_ const OrtAllocatorInfo* ptr, _Out_ OrtMemType* out);
+ORT_API_STATUS(OrtAllocatorInfoGetType, _In_ const OrtAllocatorInfo* ptr, _Out_ OrtAllocatorType* out);
-ORT_API_STATUS(OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Out_ void** out);
+ORT_API_STATUS(OrtAllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out);
ORT_API_STATUS(OrtAllocatorFree, _Inout_ OrtAllocator* ptr, void* p);
ORT_API_STATUS(OrtAllocatorGetInfo, _In_ const OrtAllocator* ptr, _Out_ const OrtAllocatorInfo** out);
-ORT_API_STATUS(OrtCreateDefaultAllocator, _Out_ OrtAllocator** out);
-ORT_API(void, OrtReleaseAllocator, _In_ OrtAllocator* allocator);
+ORT_API_STATUS(OrtCreateDefaultAllocator, _Outptr_ OrtAllocator** out);
ORT_API(const char*, OrtGetVersionString);
/**
@@ -509,13 +510,13 @@ ORT_ALL_ARGS_NONNULL;
* If input OrtValue represents a sequence, use index to retrieve the index'th element
* of the sequence.
*/
-ORT_API_STATUS(OrtGetValue, const OrtValue* value, int index, OrtAllocator* allocator, OrtValue** out);
+ORT_API_STATUS(OrtGetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out);
/**
* Returns 2 for type map and N for sequence where N is the number of elements
* in the sequence.
*/
-ORT_API_STATUS(OrtGetValueCount, const OrtValue* value, size_t* out);
+ORT_API_STATUS(OrtGetValueCount, _In_ const OrtValue* value, _Out_ size_t* out);
/**
* To construct a map, use num_values = 2 and 'in' should be an arrary of 2 OrtValues
@@ -524,8 +525,8 @@ ORT_API_STATUS(OrtGetValueCount, const OrtValue* value, size_t* out);
* sequence. 'in' should be an arrary of N OrtValues.
* \value_type should be either map or sequence.
*/
-ORT_API_STATUS(OrtCreateValue, OrtValue** in, size_t num_values, enum ONNXType value_type,
- OrtValue** out);
+ORT_API_STATUS(OrtCreateValue, _In_ const OrtValue* const* in, size_t num_values, enum ONNXType value_type,
+ _Outptr_ OrtValue** out);
/*
* EXPERIMENTAL APIS - Subject to change. Released as a preview to get feedback and enable early testing
@@ -548,8 +549,9 @@ struct OrtCustomOpApi {
*/
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
+ OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_string)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size);
- OrtStatus*(ORT_API_CALL* GetTensorTypeAndShape)(_In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
+ OrtStatus*(ORT_API_CALL* GetTensorTypeAndShape)(_In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
OrtStatus*(ORT_API_CALL* GetTensorShapeElementCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
OrtStatus*(ORT_API_CALL* GetTensorElementType)(_In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out);
@@ -557,14 +559,14 @@ struct OrtCustomOpApi {
OrtStatus*(ORT_API_CALL* GetDimensionCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
OrtStatus*(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
OrtStatus*(ORT_API_CALL* SetDimensions)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
- OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Out_ void** data);
+ OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Outptr_ void** data);
- void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(OrtTensorTypeAndShapeInfo* input);
+ void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(_In_ OrtTensorTypeAndShapeInfo* input);
- OrtStatus*(ORT_API_CALL* KernelContext_GetInputCount)(const OrtKernelContext* context, _Out_ size_t* out);
- OrtStatus*(ORT_API_CALL* KernelContext_GetInput)(const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out);
- OrtStatus*(ORT_API_CALL* KernelContext_GetOutputCount)(const OrtKernelContext* context, _Out_ size_t* out);
- OrtStatus*(ORT_API_CALL* KernelContext_GetOutput)(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out);
+ OrtStatus*(ORT_API_CALL* KernelContext_GetInputCount)(_In_ const OrtKernelContext* context, _Out_ size_t* out);
+ OrtStatus*(ORT_API_CALL* KernelContext_GetInput)(_In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out);
+ OrtStatus*(ORT_API_CALL* KernelContext_GetOutputCount)(_In_ const OrtKernelContext* context, _Out_ size_t* out);
+ OrtStatus*(ORT_API_CALL* KernelContext_GetOutput)(_Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out);
};
typedef struct OrtCustomOpApi OrtCustomOpApi;
@@ -599,19 +601,19 @@ typedef struct OrtCustomOp OrtCustomOp;
/*
* Create a custom op domain. After all sessions using it are released, call OrtReleaseCustomOpDomain
*/
-ORT_API_STATUS(OrtCreateCustomOpDomain, _In_ const char* domain, _Out_ OrtCustomOpDomain** out);
+ORT_API_STATUS(OrtCreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out);
/*
* Add custom ops to the OrtCustomOpDomain
* Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released
*/
-ORT_API_STATUS(OrtCustomOpDomain_Add, _In_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op);
+ORT_API_STATUS(OrtCustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op);
/*
* Add a custom op domain to the OrtSessionOptions
* Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released
*/
-ORT_API_STATUS(OrtAddCustomOpDomain, _In_ OrtSessionOptions* options, OrtCustomOpDomain* custom_op_domain);
+ORT_API_STATUS(OrtAddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain);
/*
* END EXPERIMENTAL
*/
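
Taken together, the run-options changes replace the single `OrtRunOptionsSetTerminate(options, flag)` call with an enable/disable pair. A minimal sketch of how a caller might cancel an in-flight `OrtRun` from a controller thread, assuming the usual status checking is handled elsewhere and that `OrtCreateRunOptions`/`OrtReleaseRunOptions` are the standard create/release helpers:

```c++
// Hedged sketch only: cancelling an in-flight OrtRun with the new enable/disable pair.
// OrtStatus* return values are ignored here for brevity.
#include "onnxruntime_c_api.h"

OrtRunOptions* run_options = nullptr;
OrtCreateRunOptions(&run_options);

// ... hand run_options to the thread that calls OrtRun(...) ...

// From a controller thread: ask any OrtRun using these options to exit early.
OrtRunOptionsEnableTerminate(run_options);

// Once the run has stopped, clear the flag so the same options can be reused.
OrtRunOptionsDisableTerminate(run_options);

OrtReleaseRunOptions(run_options);
```
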
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index df15d2d2ecde6..e21e87596781e 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -73,8 +73,8 @@ struct Base {
protected:
Base(const Base&) = delete;
- Base(Base&& v) : p_{v.p_} { v.p_ = nullptr; }
- void operator=(Base&& v) {
+ Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
+ void operator=(Base&& v) noexcept {
OrtRelease(p_);
p_ = v.p_;
v.p_ = nullptr;
@@ -101,8 +101,8 @@ struct Value;
struct Env : Base<OrtEnv> {
Env(nullptr_t) {}
- Env(OrtLoggingLevel default_warning_level, _In_ const char* logid);
- Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
+ Env(OrtLoggingLevel default_logging_level, _In_ const char* logid);
+ Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
explicit Env(OrtEnv* p) : Base{p} {}
};
@@ -117,13 +117,14 @@ struct RunOptions : Base<OrtRunOptions> {
RunOptions(nullptr_t) {}
RunOptions();
- RunOptions& SetRunLogVerbosityLevel(unsigned int);
- unsigned int GetRunLogVerbosityLevel() const;
+ RunOptions& SetRunLogVerbosityLevel(int);
+ int GetRunLogVerbosityLevel() const;
RunOptions& SetRunTag(const char* run_tag);
const char* GetRunTag() const;
- RunOptions& SetTerminate(bool flag);
+ RunOptions& EnableTerminate();
+ RunOptions& DisableTerminate();
};
struct SessionOptions : Base<OrtSessionOptions> {
@@ -134,7 +135,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions Clone() const;
SessionOptions& SetThreadPoolSize(int session_thread_pool_size);
- SessionOptions& SetGraphOptimizationLevel(uint32_t graph_optimization_level);
+ SessionOptions& SetGraphOptimizationLevel(int graph_optimization_level);
SessionOptions& EnableCpuMemArena();
SessionOptions& DisableCpuMemArena();
@@ -252,7 +253,7 @@ struct AllocatorInfo : Base<OrtAllocatorInfo> {
struct CustomOpApi {
CustomOpApi(const OrtCustomOpApi& api) : api_(api) {}
-  template <typename T>  // T is only implemented for float and int64_t
+  template <typename T>  // T is only implemented for float, int64_t, and string
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
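
On the C++ side the same changes surface as an `int` verbosity level and an `EnableTerminate`/`DisableTerminate` pair on `RunOptions`. A small hedged sketch of the updated usage, assuming the wrappers above live in the usual `Ort` namespace:

```c++
// Hedged sketch of the updated Ort::RunOptions surface.
#include "onnxruntime_cxx_api.h"

Ort::RunOptions run_options;
run_options.SetRunLogVerbosityLevel(2);  // now a signed int instead of unsigned
run_options.SetRunTag("example-run");

run_options.EnableTerminate();   // request early exit of any Run() using these options
run_options.DisableTerminate();  // re-arm the options for the next Run()
```
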
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 970155aeaa383..0fbbbde445b16 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+// Don't include this file directly. Please include "onnxruntime_cxx_api.h" instead.
// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
// the main C++ file with implementation details.
@@ -90,13 +91,13 @@ inline RunOptions::RunOptions() {
ORT_THROW_ON_ERROR(OrtCreateRunOptions(&p_));
}
-inline RunOptions& RunOptions::SetRunLogVerbosityLevel(unsigned int level) {
+inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
ORT_THROW_ON_ERROR(OrtRunOptionsSetRunLogVerbosityLevel(p_, level));
return *this;
}
-inline unsigned int RunOptions::GetRunLogVerbosityLevel() const {
- unsigned int out;
+inline int RunOptions::GetRunLogVerbosityLevel() const {
+ int out;
ORT_THROW_ON_ERROR(OrtRunOptionsGetRunLogVerbosityLevel(p_, &out));
return out;
}
@@ -112,8 +113,13 @@ inline const char* RunOptions::GetRunTag() const {
return out;
}
-inline RunOptions& RunOptions::SetTerminate(bool flag) {
- ORT_THROW_ON_ERROR(OrtRunOptionsSetTerminate(p_, flag ? 1 : 0));
+inline RunOptions& RunOptions::EnableTerminate() {
+ ORT_THROW_ON_ERROR(OrtRunOptionsEnableTerminate(p_));
+ return *this;
+}
+
+inline RunOptions& RunOptions::DisableTerminate() {
+ ORT_THROW_ON_ERROR(OrtRunOptionsDisableTerminate(p_));
return *this;
}
@@ -132,7 +138,7 @@ inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool
return *this;
}
-inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(uint32_t graph_optimization_level) {
+inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(int graph_optimization_level) {
ORT_THROW_ON_ERROR(OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level));
return *this;
}
@@ -283,7 +289,7 @@ inline Unowned<TensorTypeAndShapeInfo> TypeInfo::GetTensorTypeAndShapeInfo() con
inline ONNXType TypeInfo::GetONNXType() const {
ONNXType out;
- ORT_THROW_ON_ERROR(OrtOnnxTypeFromTypeInfo(p_, &out));
+ ORT_THROW_ON_ERROR(OrtGetOnnxTypeFromTypeInfo(p_, &out));
return out;
}
@@ -393,6 +399,24 @@ inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernel
return out;
}
+template <>
+inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
+ size_t size = 0;
+ std::string out;
+ OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
+
+ // The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
+ if (OrtGetErrorCode(status) == ORT_INVALID_ARGUMENT) {
+ OrtReleaseStatus(status);
+ out.resize(size);
+ ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
+ out.resize(size - 1); // remove the terminating character '\0'
+ } else {
+ ORT_THROW_ON_ERROR(status);
+ }
+ return out;
+}
+
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
OrtTensorTypeAndShapeInfo* out;
ORT_THROW_ON_ERROR(api_.GetTensorTypeAndShape(value, &out));
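
The new `KernelInfoGetAttribute_string` entry follows the usual two-call pattern (query the size, then fill the buffer), which the `std::string` specialization above wraps. A hedged sketch of reading a string attribute from inside a custom op kernel; the kernel struct and the attribute name `mode` are illustrative only:

```c++
// Hedged sketch: reading a string attribute in a custom op kernel constructor.
// `api` and `info` are the OrtCustomOpApi / OrtKernelInfo handed to the custom op's
// CreateKernel callback; the attribute name "mode" is a made-up example.
#include <string>
#include "onnxruntime_cxx_api.h"

struct ExampleKernel {
  ExampleKernel(const OrtCustomOpApi& api, const OrtKernelInfo* info) : ort_(api) {
    mode_ = ort_.KernelInfoGetAttribute<std::string>(info, "mode");
  }

  Ort::CustomOpApi ort_;
  std::string mode_;
};
```
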
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index ad73baf144239..29e8f5fb33ebf 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -12,7 +12,7 @@
as Deep Learning algorithms in the
`ONNX-ML format `_.
"""
-__version__ = "0.4.0"
+__version__ = "0.5.0"
__author__ = "Microsoft"
from onnxruntime.capi import onnxruntime_validation
diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc
index 23eb0cc8e1424..7f7102475c620 100644
--- a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc
+++ b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc
@@ -228,77 +228,122 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
  gsl::span<T> last_cell_2 = last_cell.subspan(last_cell_size_per_direction,
last_cell_size_per_direction);
-  auto fam = std::make_unique<BahdanauAttention<T>>(
-      alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
-  fam->SetWeights(
+  BahdanauAttention<T> fam(
+ alloc,
+ logger,
+ batch_size,
+ max_memory_step,
+ memory_depth,
+ query_depth,
+ am_attn_size,
+ false);
+
+ fam.SetWeights(
      FirstHalfSpan(am_v_weights.DataAsSpan<T>()),
      FirstHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
      FirstHalfSpan(am_memory_layer_weights.DataAsSpan<T>()));
-  fam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
-
-  auto faw = std::make_unique<AttentionWrapper<T>>(
-      alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam);
-  faw->SetWeights(FirstHalfSpan(attn_layer_weights_span));
-
-  auto fw = std::make_unique<UniDirectionalAttnLstm<T>>(
+  fam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
+
+  AttentionWrapper<T> faw(
+ alloc,
+ logger,
+ batch_size,
+ memory_depth,
+ attn_layer_depth,
+ hidden_size_,
+ has_attention_layer,
+ fam);
+ faw.SetWeights(FirstHalfSpan(attn_layer_weights_span));
+
+  UniDirectionalAttnLstm<T> fw(
alloc, logger,
seq_length, batch_size, input_size,
- hidden_size_, Direction::kForward, input_forget_, *faw,
+ hidden_size_, Direction::kForward, input_forget_, faw,
bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, ttp_);
-  auto bam = std::make_unique<BahdanauAttention<T>>(
-      alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
-  bam->SetWeights(
+  BahdanauAttention<T> bam(
+ alloc,
+ logger,
+ batch_size,
+ max_memory_step,
+ memory_depth,
+ query_depth,
+ am_attn_size,
+ false);
+ bam.SetWeights(
      SecondHalfSpan(am_v_weights.DataAsSpan<T>()),
      SecondHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
      SecondHalfSpan(am_memory_layer_weights.DataAsSpan<T>()));
-  bam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
-
-  auto baw = std::make_unique<AttentionWrapper<T>>(
-      alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *bam);
-  baw->SetWeights(SecondHalfSpan(attn_layer_weights_span));
-
-  auto bw = std::make_unique<UniDirectionalAttnLstm<T>>(
+  bam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
+
+  AttentionWrapper<T> baw(
+ alloc,
+ logger,
+ batch_size,
+ memory_depth,
+ attn_layer_depth,
+ hidden_size_,
+ has_attention_layer,
+ bam);
+ baw.SetWeights(SecondHalfSpan(attn_layer_weights_span));
+
+  UniDirectionalAttnLstm<T> bw(
alloc, logger,
seq_length, batch_size, input_size,
- hidden_size_, Direction::kReverse, input_forget_, *baw,
+ hidden_size_, Direction::kReverse, input_forget_, baw,
bias_2, peephole_weights_2, initial_hidden_2, initial_cell_2,
activation_funcs_.Entries()[3],
activation_funcs_.Entries()[4],
activation_funcs_.Entries()[5],
clip_, ttp_);
- fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
- bw->Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
+ fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
+ bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
} else {
-  auto fam = std::make_unique<BahdanauAttention<T>>(
-      alloc, logger, batch_size, max_memory_step, memory_depth, query_depth, am_attn_size, false);
-  fam->SetWeights(
+  BahdanauAttention<T> fam(
+ alloc,
+ logger,
+ batch_size,
+ max_memory_step,
+ memory_depth,
+ query_depth,
+ am_attn_size,
+ false);
+
+ fam.SetWeights(
      am_v_weights.DataAsSpan<T>(),
      am_query_layer_weights.DataAsSpan<T>(),
      am_memory_layer_weights.DataAsSpan<T>());
-  fam->PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
+  fam.PrepareMemory(attn_memory.DataAsSpan<T>(), memory_seq_lens_span);
+
+  AttentionWrapper<T> faw(
+ alloc,
+ logger,
+ batch_size,
+ memory_depth,
+ attn_layer_depth,
+ hidden_size_,
+ has_attention_layer,
+ fam);
-  auto faw = std::make_unique<AttentionWrapper<T>>(
-      alloc, logger, batch_size, memory_depth, attn_layer_depth, hidden_size_, has_attention_layer, *fam);
-  faw->SetWeights(attn_layer_weights_span);
+  faw.SetWeights(attn_layer_weights_span);
-  auto fw = std::make_unique<UniDirectionalAttnLstm<T>>(
+  UniDirectionalAttnLstm<T> fw(
alloc, logger,
seq_length, batch_size, input_size,
- hidden_size_, direction_, input_forget_, *faw,
+ hidden_size_, direction_, input_forget_, faw,
bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, ttp_);
- fw->Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
+ fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
}
if (!output.empty()) {
diff --git a/onnxruntime/contrib_ops/cpu/fused_activation.cc b/onnxruntime/contrib_ops/cpu/fused_activation.cc
new file mode 100644
index 0000000000000..d63e19991e754
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/fused_activation.cc
@@ -0,0 +1,49 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/fused_activation.h"
+
+namespace onnxruntime {
+
+common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation) {
+ // Convert the activation parameters from the node into a MLAS_ACTIVATION.
+ activation.ActivationKind = MlasIdentityActivation;
+
+ std::string activation_type;
+ if (info.GetAttr("activation", &activation_type).IsOK()) {
+ if (activation_type == "Relu") {
+ activation.ActivationKind = MlasReluActivation;
+ } else if (activation_type == "Tanh") {
+ activation.ActivationKind = MlasTanhActivation;
+ } else if (activation_type == "Sigmoid") {
+ activation.ActivationKind = MlasLogisticActivation;
+ } else {
+ // The remaining activation types have additional parameters to be pulled out.
+ size_t activation_params_count;
+ if (activation_type == "LeakyRelu") {
+ activation.ActivationKind = MlasLeakyReluActivation;
+ activation_params_count = 1;
+ } else if (activation_type == "Clip") {
+ activation.ActivationKind = MlasClipActivation;
+ activation_params_count = 2;
+ } else {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "unimplemented activation: " + activation_type);
+ }
+
+      std::vector<float> activation_params;
+ common::Status status = info.GetAttrs("activation_params", activation_params);
+ if (!status.IsOK()) {
+ return status;
+ } else if (activation_params_count != activation_params.size()) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "activation_params count mismatch");
+ }
+ for (size_t i = 0; i < activation_params_count; i++) {
+ activation.Parameters.Values[i] = activation_params[i];
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace onnxruntime
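
The helper expects the fused activation to arrive as node attributes: `activation` names the function, and for the parameterized kinds `activation_params` carries the extra scalars. A hedged illustration of the attribute shapes it accepts (the numeric values are examples, not defaults):

```c++
// Hedged illustration of the attribute combinations GetFusedActivationAttr handles:
//   activation = "Relu" / "Tanh" / "Sigmoid"                   -> no activation_params
//   activation = "LeakyRelu", activation_params = { 0.01f }    -> one alpha value
//   activation = "Clip",      activation_params = { 0.f, 6.f } -> the two clip bounds
MLAS_ACTIVATION activation;
ORT_ENFORCE(GetFusedActivationAttr(info, activation).IsOK());  // info: const OpKernelInfo&
```
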
diff --git a/onnxruntime/contrib_ops/cpu/fused_activation.h b/onnxruntime/contrib_ops/cpu/fused_activation.h
new file mode 100644
index 0000000000000..0121a2038e1cb
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/fused_activation.h
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/util/math.h"
+#include "core/mlas/inc/mlas.h"
+
+namespace onnxruntime {
+
+common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation);
+
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/fused_conv.cc b/onnxruntime/contrib_ops/cpu/fused_conv.cc
index ae8f81e8129ce..2e07fa27d7cbb 100644
--- a/onnxruntime/contrib_ops/cpu/fused_conv.cc
+++ b/onnxruntime/contrib_ops/cpu/fused_conv.cc
@@ -1,16 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "fused_conv.h"
+#include "core/providers/cpu/nn/conv.h"
+#include "contrib_ops/cpu/fused_activation.h"
namespace onnxruntime {
namespace contrib {
+
+class FusedConvFloat final : public Conv<float> {
+ public:
+  FusedConvFloat(const OpKernelInfo& info) : Conv<float>(info) {
+ ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
+ }
+};
+
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedConv,
1,
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType()),
- FusedConv);
+ FusedConvFloat);
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/fused_conv.h b/onnxruntime/contrib_ops/cpu/fused_conv.h
deleted file mode 100644
index 329eb82990838..0000000000000
--- a/onnxruntime/contrib_ops/cpu/fused_conv.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/providers/cpu/nn/conv_impl.h"
-
-namespace onnxruntime {
-namespace contrib {
-
-template <typename T>
-class FusedConv : public Conv<T> {
- public:
-  FusedConv(const OpKernelInfo& info) : Conv<T>(info) {
-    Conv<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
-    Conv<T>::alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
-  }
-
-  Status Compute(OpKernelContext* context) const override {
-    return Conv<T>::Compute(context);
- }
-};
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.cc b/onnxruntime/contrib_ops/cpu/fused_gemm.cc
index e3bfe5b3881ce..d743a3fcad7be 100644
--- a/onnxruntime/contrib_ops/cpu/fused_gemm.cc
+++ b/onnxruntime/contrib_ops/cpu/fused_gemm.cc
@@ -1,15 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "fused_gemm.h"
+#include "core/providers/cpu/math/gemm.h"
namespace onnxruntime {
namespace contrib {
+
+template <typename T>
+class FusedGemm final : public Gemm<T> {
+ public:
+  FusedGemm(const OpKernelInfo& info) : Gemm<T>(info) {
+    Gemm<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
+    Gemm<T>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
+ }
+};
+
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedGemm,
1,
float,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    FusedGemm<float, float, float, float>);
+    FusedGemm<float>);
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.h b/onnxruntime/contrib_ops/cpu/fused_gemm.h
deleted file mode 100644
index 5be1b34cb41c4..0000000000000
--- a/onnxruntime/contrib_ops/cpu/fused_gemm.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/providers/cpu/math/gemm.h"
-
-namespace onnxruntime {
-namespace contrib {
-template <typename T_X,
-          typename T_W,
-          typename T_B,
-          typename T_Y>
-class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> {
- public:
-  FusedGemm(const OpKernelInfo& info) : Gemm<T_X, T_W, T_B, T_Y>(info) {
-    Gemm<T_X, T_W, T_B, T_Y>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
-    Gemm<T_X, T_W, T_B, T_Y>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
-  }
-
-  Status Compute(OpKernelContext* context) const override {
-    return Gemm<T_X, T_W, T_B, T_Y>::Compute(context);
- }
-};
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
new file mode 100644
index 0000000000000..b5625551ad104
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
@@ -0,0 +1,205 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/framework/op_kernel_context_internal.h"
+#include "nchwc_ops.h"
+#include "core/mlas/inc/mlas.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+#define ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(name, ver, type, builder, ...) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSNchwcDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ ReorderInput,
+ 1,
+ float,
+ KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    ReorderInput<float>);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ ReorderOutput,
+ 1,
+ float,
+ KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    ReorderOutput<float>);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ Conv,
+ 1,
+ float,
+ KernelDefBuilder()
+ .MayInplace(3, 0)
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcConv);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ MaxPool,
+ 1,
+ float,
+ KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcMaxPool);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ GlobalMaxPool,
+ 1,
+ float,
+ KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcMaxPool);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ AveragePool,
+ 1,
+ float,
+ KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcAveragePool);
+
+ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
+ GlobalAveragePool,
+ 1,
+ float,
+ KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ NchwcAveragePool);
+
+template <typename T>
+Status ReorderInput<T>::Compute(OpKernelContext* context) const {
+  const auto* X = context->Input<Tensor>(0);
+ const auto& X_shape = X->Shape();
+ ORT_ENFORCE(X_shape.NumDimensions() == 4);
+ ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
+ auto* Y = context->Output(0, X_shape);
+  MlasReorderInput(X_shape.GetDims().data(), X->template Data<T>(), Y->template MutableData<T>());
+ return Status::OK();
+}
+
+template <typename T>
+Status ReorderOutput<T>::Compute(OpKernelContext* context) const {
+  const auto* X = context->Input<Tensor>(0);
+ const auto& X_shape = X->Shape();
+ ORT_ENFORCE(X_shape.NumDimensions() == 4);
+  std::vector<int64_t> Y_shape(X_shape.GetDims());
+ ORT_ENFORCE(channels_ <= Y_shape[1]);
+ Y_shape[1] = channels_;
+ auto* Y = context->Output(0, Y_shape);
+  MlasReorderOutput(Y_shape.data(), X->template Data<T>(), Y->template MutableData<T>());
+ return Status::OK();
+}
+
+Status NchwcConv::Compute(OpKernelContext* context) const {
+  const auto* X = context->Input<Tensor>(0);
+  const auto* W = context->Input<Tensor>(1);
+  const auto* B = context->Input<Tensor>(2);
+  const auto* Sum = context->Input<Tensor>(3);
+
+ ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W));
+
+ const auto& X_shape = X->Shape();
+ const auto& W_shape = W->Shape();
+ ORT_ENFORCE(X_shape.NumDimensions() == 4);
+
+ const size_t nchwc_block_size = MlasNchwcGetBlockSize();
+ ORT_ENFORCE((static_cast(X_shape[1]) < nchwc_block_size) || ((X_shape[1] % nchwc_block_size) == 0));
+
+  std::vector<int64_t> kernel_shape;
+ ORT_RETURN_IF_ERROR(ConvBase::ComputeKernelShape(W_shape, kernel_shape));
+ if (kernel_shape.size() != 2) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported convolution size.");
+ }
+
+  std::vector<int64_t> pads(ConvBase::pads_);
+  if (pads.empty()) {
+    pads.resize(kernel_shape.size() * 2, 0);
+  }
+  std::vector<int64_t> dilations(ConvBase::dilations_);
+  if (dilations.empty()) {
+    dilations.resize(kernel_shape.size(), 1);
+  }
+  std::vector<int64_t> strides(ConvBase::strides_);
+  if (strides.empty()) {
+    strides.resize(kernel_shape.size(), 1);
+  }
+
+  std::vector<int64_t> Y_dims;
+ Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]});
+ TensorShape input_shape = X->Shape().Slice(2);
+ ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
+ auto* Y = context->Output(0, Y_dims);
+  auto* y_data = Y->template MutableData<float>();
+
+ // Check for the optional Conv/Sum fusion.
+ if (Sum != nullptr) {
+ const auto& sum_shape = Sum->Shape();
+ ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match");
+ // If the output was not allocated inplace with the sum tensor, then copy here.
+    const auto* sum_data = Sum->template Data<float>();
+ if (y_data != sum_data) {
+ memcpy(y_data, sum_data, sum_shape.Size() * sizeof(float));
+ }
+ }
+
+ MlasNchwcConv(kernel_shape.size(),
+ X_shape.GetDims().data(),
+ kernel_shape.data(),
+ dilations.data(),
+ pads.data(),
+ strides.data(),
+ Y_dims.data(),
+                static_cast<size_t>(ConvBase::group_),
+                X->template Data<float>(),
+                W->template Data<float>(),
+                B != nullptr ? B->template Data<float>() : nullptr,
+ y_data,
+ &activation_,
+ Sum == nullptr,
+                const_cast<concurrency::ThreadPool*>(static_cast<OpKernelContextInternal*>(context)->GetOperatorThreadPool()));
+
+ return Status::OK();
+}
+
+Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const {
+  const auto* X = context->Input<Tensor>(0);
+
+ const auto& X_shape = X->Shape();
+ ORT_ENFORCE(X_shape.NumDimensions() == 4);
+ ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
+
+ if (!global_pooling_) {
+ ORT_RETURN_IF_NOT(kernel_shape_.size() == 2, "kernel_shape num_dims is not compatible with X num_dims.");
+ }
+
+  std::vector<int64_t> pads = pads_;
+  std::vector<int64_t> output_dims = PoolBase::SetOutputSize(X_shape, X_shape[1], &pads, dilations_, ceil_mode_);
+ auto* Y = context->Output(0, output_dims);
+
+ MlasNchwcPool(kind,
+ 2,
+ X_shape.GetDims().data(),
+ global_pooling_ ? nullptr : kernel_shape_.data(),
+ global_pooling_ ? nullptr : dilations_.data(),
+ global_pooling_ ? nullptr : pads.data(),
+ global_pooling_ ? nullptr : strides_.data(),
+ output_dims.data(),
+                X->template Data<float>(),
+                Y->template MutableData<float>(),
+                const_cast<concurrency::ThreadPool*>(static_cast<OpKernelContextInternal*>(context)->GetOperatorThreadPool()));
+
+ return Status::OK();
+}
+
+Status NchwcMaxPool::Compute(OpKernelContext* context) const {
+ return NchwcPoolBase::NchwcPool(context, MlasMaximumPooling);
+}
+
+Status NchwcAveragePool::Compute(OpKernelContext* context) const {
+ return NchwcPoolBase::NchwcPool(context, count_include_pad_ ? MlasAveragePoolingIncludePad : MlasAveragePoolingExcludePad);
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h
new file mode 100644
index 0000000000000..65045cd0eeb85
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/nn/conv_base.h"
+#include "core/providers/cpu/nn/pool.h"
+#include "contrib_ops/cpu/fused_activation.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+class ReorderInput : public OpKernel {
+ public:
+ ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+};
+
+template <typename T>
+class ReorderOutput : public OpKernel {
+ public:
+ ReorderOutput(const OpKernelInfo& info) : OpKernel(info) {
+ ORT_ENFORCE(info.GetAttr("channels", &channels_).IsOK());
+ ORT_ENFORCE(channels_ > 0, "invalid channel count");
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int64_t channels_;
+};
+
+class NchwcConv : public OpKernel, public ConvBase {
+ public:
+ NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
+ ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ MLAS_ACTIVATION activation_;
+};
+
+class NchwcPoolBase : public PoolBase {
+ public:
+ NchwcPoolBase(const OpKernelInfo& info) : PoolBase(info) {
+ }
+
+ Status NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const;
+};
+
+class NchwcMaxPool : public OpKernel, public NchwcPoolBase {
+ public:
+ NchwcMaxPool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+};
+
+class NchwcAveragePool : public OpKernel, public NchwcPoolBase {
+ public:
+ NchwcAveragePool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc
index e9994011aa039..8446a35bd8947 100644
--- a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc
@@ -3,6 +3,7 @@
#include "contrib_ops/cpu_contrib_kernels.h"
#include "core/graph/constants.h"
+#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
namespace contrib {
@@ -49,6 +50,29 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Sca
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, Conv);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, GlobalMaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, AveragePool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, GlobalAveragePool);
+
+void RegisterNchwcKernels(KernelRegistry& kernel_registry) {
+ static const BuildKernelCreateInfoFn function_table[] = {
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, Conv)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, GlobalMaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, AveragePool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, GlobalAveragePool)>};
+
+ for (auto& function_table_entry : function_table) {
+ kernel_registry.Register(function_table_entry());
+ }
+}
+
void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo,
@@ -96,6 +120,12 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
for (auto& function_table_entry : function_table) {
kernel_registry.Register(function_table_entry());
}
+
+ // Register the NCHWc kernels if supported by the platform.
+ if (MlasNchwcGetBlockSize() > 1) {
+ RegisterNchwcKernels(kernel_registry);
+ }
}
+
} // namespace contrib
} // namespace onnxruntime
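
The NCHWc kernels are only registered when MLAS reports a blocked layout for the host CPU, which is what the `MlasNchwcGetBlockSize() > 1` guard checks. A hedged sketch of the same probe from application code, for example to log whether the blocked kernels can be used:

```c++
// Hedged sketch: MlasNchwcGetBlockSize() returns 1 when no NCHWc-blocked kernels are
// available for this CPU, and the per-channel block width (e.g. 8 or 16 floats) otherwise.
#include <cstdio>
#include "core/mlas/inc/mlas.h"

void ReportNchwcSupport() {
  const size_t block_size = MlasNchwcGetBlockSize();
  if (block_size > 1) {
    std::printf("NCHWc kernels available, block size = %zu\n", block_size);
  } else {
    std::printf("NCHWc kernels not available on this CPU\n");
  }
}
```
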
diff --git a/onnxruntime/core/codegen/common/common.cc b/onnxruntime/core/codegen/common/common.cc
new file mode 100644
index 0000000000000..757c1677dd2e5
--- /dev/null
+++ b/onnxruntime/core/codegen/common/common.cc
@@ -0,0 +1,258 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/codegen/common/common.h"
+
+#include "core/framework/tensorprotoutils.h"
+#include "core/graph/graph.h"
+#include "core/graph/schema_registry.h"
+#include <algorithm>
+#include <unordered_set>
+
+namespace onnxruntime {
+
+NodeKey GetKey(const onnxruntime::Node* node) {
+ ORT_ENFORCE(nullptr != node);
+ ORT_ENFORCE(node->OutputDefs().size() > 0);
+ return node->OutputDefs()[0]->Name();
+}
+
+NodeKey GetKey(const onnxruntime::Node& node) {
+ ORT_ENFORCE(node.OutputDefs().size() > 0);
+ return node.OutputDefs()[0]->Name();
+}
+
+NodeKey GetKey(const onnxruntime::NodeArg* def) {
+ // NodeArg's name is unique.
+ ORT_ENFORCE(nullptr != def);
+ return def->Name();
+}
+
+bool IsRecurrentNode(const onnxruntime::Node& node) {
+ auto op_type = node.OpType();
+ return (op_type == "LSTM" || op_type == "RNN" || op_type == "GRU" ||
+ op_type == "Scan" || op_type == "Loop");
+}
+
+bool IsAliasNode(const onnxruntime::Node& node) {
+ auto op_type = node.OpType();
+ return (op_type == "Flatten" || op_type == "Identity" || op_type == "Reshape" ||
+ op_type == "Squeeze" || op_type == "Unsqueeze");
+}
+
+std::string NormalizeCppName(const std::string& name) {
+ std::string normalized_name = name;
+ for (char c : {'.', ' ', '+', '-', '*', '/', '\\', '='})
+ std::replace(normalized_name.begin(), normalized_name.end(), c, '_');
+ return normalized_name;
+}
+
+std::string NormalizeNodeArgName(const NodeArg* def) {
+ return NormalizeCppName(def->Name());
+}
+
+bool IsFusedNode(const Node& node) {
+ if (node.NodeType() == Node::Type::Fused) {
+ return true;
+ }
+ return false;
+}
+
+// A unified API to get Subgraph
+const Graph* GetSubgraph(const Node& node) {
+ if (node.NodeType() == Node::Type::Fused) {
+ return &(node.GetFunctionBody()->Body());
+ } else if (node.OpType() == "Scan") {
+ return node.GetGraphAttribute("body");
+ }
+ // return nullptr implying no subgraph
+ return nullptr;
+}
+
+bool HasLoop(const Node& node) {
+ auto op_type = node.OpType();
+ if (op_type == "LSTM" ||
+ op_type == "GRU" ||
+ op_type == "RNN" ||
+ op_type == "Scan") {
+ return true;
+ }
+ return false;
+}
+
+// Return the corresponding input node for the NodeArg of the given node
+const onnxruntime::Node* GetInputNode(const Node& node, const NodeArg* def) {
+ const auto& input_name = def->Name();
+ const onnxruntime::Node* input_node = nullptr;
+ // search input node set to see if input_name is in their outputs (weights are not from node)
+ for (auto iter = node.InputNodesBegin(); iter != node.InputNodesEnd(); ++iter) {
+ const onnxruntime::Node& p = *iter;
+ bool found = false;
+ p.ForEachWithIndex(
+ p.OutputDefs(),
+ [&found, &input_name](const onnxruntime::NodeArg& out_def, size_t) {
+ if (input_name == out_def.Name()) {
+ found = true;
+ }
+ return Status::OK();
+ });
+ if (found)
+ input_node = &p;
+ }
+ return input_node;
+}
+
+// create capacity from subgraph
+std::unique_ptr<ComputeCapability> ToCapacity(const onnxruntime::GraphViewer& graph,
+                                              std::unique_ptr<IndexedSubGraph>& subgraph) {
+ auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
+ static int fuse_count = 0;
+ meta_def->name = "Fuse" + std::to_string(fuse_count++);
+ meta_def->domain = "Fuse";
+
+  std::set<NodeIndex> node_indices(subgraph->nodes.begin(), subgraph->nodes.end());
+
+ const auto& start_node_index = subgraph->nodes.front();
+ const auto& start_node = *graph.GetNode(start_node_index);
+ const auto& end_node_index = subgraph->nodes.back();
+ const auto& end_node = *graph.GetNode(end_node_index);
+ meta_def->name += start_node.OpType() + std::to_string(start_node_index);
+ meta_def->name += "_With" + std::to_string(subgraph->nodes.size()) + "Nodes_";
+ meta_def->name += end_node.OpType() + std::to_string(end_node_index);
+
+ for (const auto& node_index : subgraph->nodes) {
+ const auto& node = *graph.GetNode(node_index);
+ // handle current graph's inputs
+ node.ForEachWithIndex(
+ node.InputDefs(),
+ [&meta_def, &node, &node_indices](const onnxruntime::NodeArg& def, size_t) {
+ const onnxruntime::Node* input_node = GetInputNode(node, &def);
+ bool input_from_subgraph = (input_node && node_indices.count(input_node->Index()));
+ if (!input_from_subgraph) {
+ // input is from weights or outside of graph
+ meta_def->inputs.push_back(def.Name());
+ }
+ return Status::OK();
+ });
+
+    // Handle outputs
+    // two cases are considered as outputs
+ // 1. Output NodeArg is not used by any Node
+ // 2. Output NodeArg is used by at least one Node out of this subgraph.
+ // Note a NodeArg can be used by Nodes in and out of the subgraph at the same time.
+
+ auto InsertOutputToSubgraph = [&meta_def](const NodeArg* def) {
+ if (std::find(meta_def->outputs.begin(), meta_def->outputs.end(), def->Name()) ==
+ meta_def->outputs.end()) {
+ meta_def->outputs.push_back(def->Name());
+ }
+ };
+
+    std::unordered_set<std::string> input_names_from_the_output_node;
+
+ for (auto o_iter = node.OutputEdgesBegin(); o_iter != node.OutputEdgesEnd(); ++o_iter) {
+ const auto& p = *o_iter;
+ const Node& out_node = p.GetNode();
+
+ // preprocess for the case 1
+ out_node.ForEachWithIndex(
+ out_node.InputDefs(),
+ [&input_names_from_the_output_node](const onnxruntime::NodeArg& in_def, size_t) {
+ input_names_from_the_output_node.insert(in_def.Name());
+ return Status::OK();
+ });
+
+ // handle the case 2
+ if (node_indices.count(out_node.Index()) == 0) {
+ const NodeArg* def = node.OutputDefs()[p.GetSrcArgIndex()];
+ InsertOutputToSubgraph(def);
+ }
+ }
+
+ // handle case 1
+ node.ForEachWithIndex(
+ node.OutputDefs(),
+ [&](const onnxruntime::NodeArg& def, size_t) {
+ if (input_names_from_the_output_node.count(def.Name()) == 0) {
+ InsertOutputToSubgraph(&def);
+ }
+ return Status::OK();
+ });
+ }
+
+ // Handle subgraph's initializers
+ const auto& all_initializers = graph.GetAllInitializedTensors();
+ for (const auto& node_index : subgraph->nodes) {
+ const auto& node = *graph.GetNode(node_index);
+ // check whether it is an immediate nested subgraph
+ auto immediate_nested_subgraph = GetSubgraph(node);
+ // If so, copy the immediate nested subgraph's initializers to meta_def->inputs.
+    // Note we don't need recursion here, since Ort has already recursed for us by handling subgraphs earlier than the current graph.
+    // Therefore, all inner nested subgraphs' initializers should already be in the immediate nested subgraph's inputs.
+ if (nullptr != immediate_nested_subgraph) {
+ for (auto& n : immediate_nested_subgraph->Nodes()) {
+ n.ForEachWithIndex(
+ n.InputDefs(),
+ [&meta_def, &all_initializers](const onnxruntime::NodeArg& def, size_t) {
+ auto iter = all_initializers.find(def.Name());
+ if (iter != all_initializers.end()) {
+ meta_def->inputs.push_back(def.Name());
+ }
+ return Status::OK();
+ });
+ }
+ }
+ }
+
+ meta_def->since_version = 1;
+ meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
+  std::unique_ptr<IndexedSubGraph> finished_subgraph(subgraph.release());
+  finished_subgraph->SetMetaDef(meta_def);
+  return std::make_unique<ComputeCapability>(std::move(finished_subgraph));
+}
+
+int64_t ShapeRank(const NodeArg* def) {
+ ORT_ENFORCE_DEBUG(nullptr != def);
+  return gsl::narrow_cast<int64_t>(def->Shape()->dim_size());
+}
+
+bool ShapeHasValue(const NodeArg* def, int i) {
+ ORT_ENFORCE_DEBUG(nullptr != def);
+ ORT_ENFORCE_DEBUG(i >= 0);
+ ORT_ENFORCE_DEBUG(i < def->Shape()->dim_size());
+ return def->Shape()->dim(i).has_dim_value();
+}
+
+bool ShapeHasSymbol(const NodeArg* def, int i) {
+ ORT_ENFORCE_DEBUG(nullptr != def);
+ ORT_ENFORCE_DEBUG(i >= 0);
+ ORT_ENFORCE_DEBUG(i < def->Shape()->dim_size());
+ return def->Shape()->dim(i).has_dim_param();
+}
+
+int64_t ShapeValue(const NodeArg* def, int i) {
+ ORT_ENFORCE_DEBUG(ShapeHasValue(def, i));
+ return def->Shape()->dim(i).dim_value();
+}
+
+const std::string& ShapeSymbol(const NodeArg* def, int i) {
+ ORT_ENFORCE_DEBUG(ShapeHasSymbol(def, i));
+ return def->Shape()->dim(i).dim_param();
+}
+
+ONNX_NAMESPACE::TensorProto_DataType TensorProtoDataType(const NodeArg* def) {
+ ORT_ENFORCE_DEBUG(nullptr != def);
+  return static_cast<ONNX_NAMESPACE::TensorProto_DataType>(def->TypeAsProto()->tensor_type().elem_type());
+}
+
+// Convert GraphNodes to internal NodePtrs without check lifetime.
+// Please use it only locally when GraphNodes still exist
+std::vector<const Node*> ConvertGraphNodesToNodePtrs(const GraphNodes& graph_nodes) {
+  std::vector<const Node*> nodes;
+ for (auto& node : graph_nodes) {
+ nodes.push_back(&node);
+ }
+ return nodes;
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/codegen/common/common.h b/onnxruntime/core/codegen/common/common.h
new file mode 100644
index 0000000000000..11ad05325a381
--- /dev/null
+++ b/onnxruntime/core/codegen/common/common.h
@@ -0,0 +1,151 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/framework/compute_capability.h"
+#include "core/framework/tensor.h"
+#include "core/graph/graph_nodes.h"
+#include "core/graph/graph_viewer.h"
+
+#ifndef NDEBUG
+#define ORT_ENFORCE_DEBUG(...) ORT_ENFORCE(__VA_ARGS__)
+#else
+#define ORT_ENFORCE_DEBUG(...)
+#endif // !NDEBUG
+
+// DYN_PROMOTE is a simplified llvm::dyn_cast, which does not need RTTI
+// DYN_PROMOTE is faster than dynamic_cast and also has smaller binary size
+// Please use DYN_PROMOTE in a critical path.
+#define DYN_PROMOTE(BASE) \
+  template <typename ToType>                                       \
+ inline const ToType* Promote(const BASE* base) { \
+ if (ToType::IsType(base)) \
+ return static_cast