diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f01a7ab14a61e..f0543f2649205 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -6,6 +6,7 @@ Do not modify directly.*
* com.microsoft.Attention
* com.microsoft.AttnLSTM
* com.microsoft.BeamSearch
+ * com.microsoft.BiasAdd
* com.microsoft.BiasDropout
* com.microsoft.BiasGelu
* com.microsoft.BiasSoftmax
@@ -468,6 +469,40 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.BiasAdd**
+
+ Add input with bias, then add residual inputs.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Inputs
+
+
+- X : T
+- Input tensor. Dimensions are (N, S, C), where N is the batch size, S is image size H*W, and C is number of channels
+- bias : T
+- Bias tensor. Dimensions are (C)
+- skip : T
+- Residual tensor. Dimensions are (N, S, C)
+
+
+#### Outputs
+
+
+- Y : T
+- The output tensor with dimensions (N, S, C)
+
+
+#### Type Constraints
+
+
+- T : tensor(float16), tensor(float)
+- Constrain input and output types to float tensors.
+
+
+
### **com.microsoft.BiasDropout**
output, dropout_mask = Dropout(data + bias, ratio) + residual, Intended to specialize the dropout pattern commonly found in transformer models.
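The BiasAdd section above boils down to Y = X + bias + skip, with the 1-D bias broadcast over the batch and spatial dimensions. As a reference sketch only (illustrative shapes, not the CUDA kernel added later in this diff, which additionally restricts C to 320, 640 or 1280):

```python
import numpy as np

def bias_add_reference(x: np.ndarray, bias: np.ndarray, skip: np.ndarray) -> np.ndarray:
    """Reference semantics of com.microsoft.BiasAdd: Y = X + bias + skip.

    x and skip have shape (N, S, C); bias has shape (C,) and is broadcast
    over the batch (N) and spatial (S = H*W) dimensions.
    """
    assert x.shape == skip.shape and bias.shape == (x.shape[-1],)
    return x + bias + skip

# Toy example; C=320 matches one of the channel counts supported by the CUDA kernel.
n, s, c = 2, 64 * 64, 320
x = np.random.randn(n, s, c).astype(np.float32)
bias = np.random.randn(c).astype(np.float32)
skip = np.random.randn(n, s, c).astype(np.float32)
print(bias_add_reference(x, bias, skip).shape)  # (2, 4096, 320)
```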
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 00b71d2946215..08178f206568e 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -787,6 +787,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index cabf890ab341d..1cefd44844f39 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -21,6 +21,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasAdd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
@@ -145,107 +147,109 @@ KernelCreateInfo BuildKernelCreateInfo() {
Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
- BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo, // backward compatibility
- BuildKernelCreateInfo, // backward compatibility
- BuildKernelCreateInfo, // backward compatibility
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo, // backward compatibility
+ BuildKernelCreateInfo, // backward compatibility
+ BuildKernelCreateInfo, // backward compatibility
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
- // These ops were experimental ops in onnx domain which have been removed now. We add them here as
- // contrib ops to maintain backward compatibility
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ // These ops were experimental ops in onnx domain which have been removed now. We add them here as
+ // contrib ops to maintain backward compatibility
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
new file mode 100644
index 0000000000000..5d5183221eda4
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/bias_add.h"
+#include "contrib_ops/cuda/diffusion/bias_add_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ BiasAdd, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ BiasAdd);
+
+REGISTER_KERNEL_TYPED(MLFloat16);
+REGISTER_KERNEL_TYPED(float);
+
+using namespace ONNX_NAMESPACE;
+
+template <typename T>
+BiasAdd<T>::BiasAdd(const OpKernelInfo& op_info) : CudaKernel(op_info) {
+}
+
+template <typename T>
+Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
+ // Input: [batch_size, height*width, channels]
+ // Bias: [channels]
+ // Skip: [batch_size, height*width, channels]
+ // Output: [batch_size, height*width, channels]
+
+ const Tensor* input = context->Input<Tensor>(0);
+
+ const auto& input_dims = input->Shape().GetDims();
+ if (input_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "The input is expected to have 3 dimensions, got ", input_dims.size());
+ }
+
+ if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Number of channels should be 320, 640 or 1280, got ", input_dims[2]);
+ }
+
+ const Tensor* bias = context->Input<Tensor>(1);
+ const auto& bias_dims = bias->Shape().GetDims();
+ if (bias_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "The bias is expected to have 1 dimensions, got ", bias_dims.size());
+ }
+ if (bias_dims[0] != input_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Number of channels in the last dimension of input and bias are not the same");
+ }
+
+ const Tensor* skip = context->Input<Tensor>(2);
+ if (skip->Shape() != input->Shape()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Shape of input and skip (residual) shall be the same");
+ }
+
+ Tensor* output = context->Output(0, input->Shape());
+
+ typedef typename ToCudaType<T>::MappedType CudaT;
+ const int32_t grid_size = static_cast<int32_t>(input_dims[0] * input_dims[1]);
+ LaunchBiasAddKernel<CudaT>(Stream(context), grid_size, static_cast<int32_t>(input_dims[2]),
+ reinterpret_cast<const CudaT*>(input->Data<T>()),
+ reinterpret_cast<const CudaT*>(bias->Data<T>()),
+ reinterpret_cast<const CudaT*>(skip->Data<T>()),
+ reinterpret_cast<CudaT*>(output->MutableData<T>()));
+
+ CUDA_RETURN_IF_ERROR(cudaPeekAtLastError());
+ return Status::OK();
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h
new file mode 100644
index 0000000000000..6f4904f4c8de9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class BiasAdd final : public CudaKernel {
+ public:
+ BiasAdd(const OpKernelInfo& op_kernel_info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
new file mode 100644
index 0000000000000..2983cc99e30b1
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
@@ -0,0 +1,79 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// The CUDA kernel is modified from SeqLen2Spatial plugin of TensorRT 8.5.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_fp16.h>
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "contrib_ops/cuda/diffusion/bias_add_impl.h"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T, int32_t C, int32_t TPB>
+__global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, T* output) {
+ int32_t base_offset = blockIdx.x * C + threadIdx.x;
+ int32_t bias_offset = threadIdx.x;
+
+#pragma unroll
+ for (int32_t i = 0; i < C / TPB; ++i) {
+ output[base_offset] = input[base_offset] + bias[bias_offset] + residual[base_offset];
+ base_offset += TPB;
+ bias_offset += TPB;
+ }
+}
+
+template __global__ void BiasAddKernel<float, 320, 320>(float const*, float const*, float const*, float*);
+template __global__ void BiasAddKernel<float, 640, 320>(float const*, float const*, float const*, float*);
+template __global__ void BiasAddKernel<float, 1280, 320>(float const*, float const*, float const*, float*);
+template __global__ void BiasAddKernel<half, 320, 320>(half const*, half const*, half const*, half*);
+template __global__ void BiasAddKernel<half, 640, 320>(half const*, half const*, half const*, half*);
+template __global__ void BiasAddKernel<half, 1280, 320>(half const*, half const*, half const*, half*);
+
+template <typename T>
+void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
+ T const* input, T const* bias, T const* residual, T* output) {
+ constexpr int32_t TPB = 320; // threads per block
+ switch (num_channels) {
+ case 320:
+ (BiasAddKernel<T, 320, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
+ break;
+ case 640:
+ (BiasAddKernel<T, 640, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
+ break;
+ case 1280:
+ (BiasAddKernel<T, 1280, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
+ break;
+ default:
+ ORT_NOT_IMPLEMENTED("Not implemented");
+ }
+}
+
+template void LaunchBiasAddKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
+ float const* input, float const* bias, float const* residual, float* output);
+
+template void LaunchBiasAddKernel<half>(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
+ half const* input, half const* bias, half const* residual, half* output);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
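The launch above assigns one thread block per (batch, spatial) position (grid_size = N*S) with a fixed TPB of 320 threads, and each thread strides through C/TPB channel elements, which is why only C of 320, 640 or 1280 is supported. A rough Python emulation of that index math, purely for illustration (not the CUDA code):

```python
import numpy as np

def emulate_bias_add_kernel(x, bias, skip, tpb=320):
    """Emulate BiasAddKernel's indexing on (N, S, C) inputs."""
    n, s, c = x.shape
    assert c % tpb == 0, "channel count must be a multiple of TPB (320)"
    flat_x = x.reshape(-1, c)
    flat_skip = skip.reshape(-1, c)
    out = np.empty_like(flat_x)
    grid_size = n * s                     # one block per (batch, spatial) row
    for block_idx in range(grid_size):    # blockIdx.x
        for thread_idx in range(tpb):     # threadIdx.x
            offset = thread_idx           # bias_offset; base_offset = block_idx * c + offset
            for _ in range(c // tpb):     # the unrolled C / TPB loop
                out[block_idx, offset] = (
                    flat_x[block_idx, offset] + bias[offset] + flat_skip[block_idx, offset]
                )
                offset += tpb
    return out.reshape(n, s, c)
```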
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h
new file mode 100644
index 0000000000000..d3397ea035959
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/common/status.h"
+#include
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
+ T const* input, T const* bias, T const* residual, T* output);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
index 3cb95dad26b36..8069cbc0a1e0e 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
@@ -23,6 +23,8 @@
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h"
+using namespace onnxruntime::cuda;
+
namespace onnxruntime {
namespace contrib {
namespace cuda {
@@ -35,13 +37,9 @@ __global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) {
#pragma unroll
for (int32_t i = 0; i < HHS / TPB; ++i) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
auto value_left = (float)(input[index_input] + bias[index_bias]);
auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]);
-#else
- auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]);
- auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]);
-#endif
+
// Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0)
float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f);
float result = value_left * gelu_right;
diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
index 14a267357371d..c6d3db7fbe6da 100644
--- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
@@ -111,5 +111,32 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
updateOutputShape(ctx, 0, output_shape);
}
}));
+
+constexpr const char* BiasAdd_ver1_doc = R"DOC(
+Add input with bias, then add residual inputs.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ BiasAdd, 1,
+ OpSchema()
+ .SetDoc(BiasAdd_ver1_doc)
+ .Input(0,
+ "X",
+ "Input tensor. Dimensions are (N, S, C), where N is the batch size, S is image size H*W, and C is number of channels",
+ "T")
+ .Input(1,
+ "bias",
+ "Bias tensor. Dimensions are (C)",
+ "T")
+ .Input(2,
+ "skip",
+ "Residual tensor. Dimensions are (N, S, C)",
+ "T")
+ .Output(0,
+ "Y",
+ "The output tensor with dimensions (N, S, C)",
+ "T")
+ .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input and output types to float tensors.")
+ .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
} // namespace contrib
} // namespace onnxruntime
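To exercise the schema registered above, a BiasAdd node can be built in the com.microsoft domain with onnx.helper. The sketch below uses made-up tensor names and shapes; actually running the model requires a CUDA-enabled onnxruntime build, since this change only registers a CUDA kernel.

```python
import onnx
from onnx import TensorProto, helper

N, S, C = 1, 4096, 320  # C must be 320, 640 or 1280 for the CUDA kernel

node = helper.make_node("BiasAdd", ["X", "bias", "skip"], ["Y"], domain="com.microsoft")
graph = helper.make_graph(
    [node],
    "bias_add_example",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [N, S, C]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [C]),
        helper.make_tensor_value_info("skip", TensorProto.FLOAT, [N, S, C]),
    ],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [N, S, C])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "bias_add_example.onnx")

# Running it requires a CUDA-enabled onnxruntime build:
# import onnxruntime as ort
# sess = ort.InferenceSession("bias_add_example.onnx", providers=["CUDAExecutionProvider"])
# y = sess.run(None, {"X": x, "bias": b, "skip": s})[0]
```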
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index bd8469909fe7f..548f0a7ecc353 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -50,6 +50,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasAdd);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist);
@@ -139,6 +140,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 842d5cd943226..53106715cd36c 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -203,6 +203,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
"GroupNorm": self._infer_GroupNorm,
"BiasSplitGelu": self._infer_BiasSplitGelu,
+ "BiasAdd": self._infer_BiasAdd,
"NhwcConv": self._infer_NhwcConv,
}
self.aten_op_dispatcher_ = {
@@ -443,6 +444,7 @@ def _onnx_infer_single_node(self, node):
"MultiHeadAttention",
"GroupNorm",
"BiasSplitGelu",
+ "BiasAdd",
"NhwcConv",
]
@@ -2104,6 +2106,9 @@ def _infer_BiasSplitGelu(self, node):
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
+ def _infer_BiasAdd(self, node):
+ self._propagate_shape_and_type(node)
+
def _infer_PythonOp(self, node):
output_tensor_types = get_attribute(node, "output_tensor_types")
assert output_tensor_types
diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py
index 437e72fce0a31..a7904c39f8491 100644
--- a/onnxruntime/python/tools/transformers/float16.py
+++ b/onnxruntime/python/tools/transformers/float16.py
@@ -111,7 +111,6 @@ def make_value_info_from_tensor(tensor):
"NonMaxSuppression",
"TopK",
"RoiAlign",
- "Resize",
"Range",
"CumSum",
"Min",
@@ -120,6 +119,10 @@ def make_value_info_from_tensor(tensor):
]
+# Some operators have the data type fixed as float for some inputs. The key is op_type and the value is a list of input indices.
+ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]}
+
+
class InitializerTracker:
"""Class for keeping track of initializer."""
@@ -198,6 +201,12 @@ def convert_float_to_float16(
queue = []
value_info_list = []
node_list = []
+
+ # Some operators (like Resize or GroupNorm) have the data type fixed as float for some inputs.
+ # When the model is converted to float16, such nodes have mixed types: some inputs are float32 and some are float16.
+ # This list keeps track of such nodes that are not in the block list.
+ mixed_float_type_node_list = []
+
# type inference on input model
if func_infer_shape is not None:
model = func_infer_shape(model)
@@ -276,9 +285,11 @@ def convert_float_to_float16(
n.output[i] = name_mapping[n.output[i]]
is_node_blocked = n.op_type in op_block_list or n.name in node_block_list
- for input in n.input:
- if input in fp32_initializers:
- fp32_initializers[input].add_node(n, is_node_blocked)
+ for i, input_name in enumerate(n.input):
+ if input_name in fp32_initializers:
+ # For Resize/GroupNorm, only the first input can be float16
+ use_fp32_weight = is_node_blocked or (n.op_type in ["Resize", "GroupNorm"] and i != 0)
+ fp32_initializers[input_name].add_node(n, use_fp32_weight)
if is_node_blocked:
node_list.append(n)
@@ -288,8 +299,14 @@ def convert_float_to_float16(
if attr.name == "to" and attr.i == 1:
attr.i = 10
break
- for attr in n.attribute:
- next_level.append(attr)
+
+ # For Resize/GroupNorm, attribute data type cannot be changed
+ if n.op_type not in ["Resize", "GroupNorm"]:
+ for attr in n.attribute:
+ next_level.append(attr)
+ else:
+ mixed_float_type_node_list.append(n)
+
# if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
# and process node.attribute.t and node.attribute.tensors (TensorProto)
if isinstance(q, onnx_proto.AttributeProto):
@@ -329,15 +346,36 @@ def convert_float_to_float16(
)
)
+ # Some operators have the data type fixed as float for some inputs. Add a float16-to-float Cast for those inputs.
+ for node in mixed_float_type_node_list:
+ for i, input_name in enumerate(node.input):
+ if i not in ALWAYS_FLOAT_INPUTS[node.op_type]:
+ continue
+ for value_info in value_info_list:
+ if input_name == value_info.name:
+ # create new value_info for current node's new input name
+ new_value_info = model.graph.value_info.add()
+ new_value_info.CopyFrom(value_info)
+ output_name = node.name + "_input_cast_" + str(i)
+ new_value_info.name = output_name
+ new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
+ # add Cast node (from tensor(float16) to tensor(float)) before current node
+ node_name = node.name + "_input_cast" + str(i)
+ new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
+ model.graph.node.extend(new_node)
+ # change current node's input name
+ node.input[i] = output_name
+ break
+
# process the nodes in block list that doesn't support tensor(float16)
for node in node_list:
# if input's name is in the value_info_list meaning input is tensor(float16) type,
# insert a float16 to float Cast node before the node,
# change current node's input name and create new value_info for the new name
for i in range(len(node.input)):
- input = node.input[i]
+ input_name = node.input[i]
for value_info in value_info_list:
- if input == value_info.name:
+ if input_name == value_info.name:
# create new value_info for current node's new input name
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(value_info)
@@ -346,7 +384,7 @@ def convert_float_to_float16(
new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
# add Cast node (from tensor(float16) to tensor(float) before current node
node_name = node.name + "_input_cast" + str(i)
- new_node = [helper.make_node("Cast", [input], [output_name], to=1, name=node_name)]
+ new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
model.graph.node.extend(new_node)
# change current node's input name
node.input[i] = output_name
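The float16.py change above keeps the inputs listed in ALWAYS_FLOAT_INPUTS (for example GroupNorm gamma/beta) in float32 and, when their producers were converted to float16, inserts a Cast back to float in front of them. A stand-alone sketch of that rewrite, using a hypothetical helper and tensor names:

```python
from onnx import TensorProto, helper

def cast_input_to_float(graph, node, input_index):
    """Insert Cast(float16 -> float) before node.input[input_index].

    Mirrors the idea used in convert_float_to_float16 for ops whose inputs
    must stay float32 (e.g. GroupNorm gamma/beta); names are illustrative.
    """
    original_input = node.input[input_index]
    cast_output = f"{node.name}_input_cast_{input_index}"
    cast_node = helper.make_node(
        "Cast",
        inputs=[original_input],
        outputs=[cast_output],
        to=TensorProto.FLOAT,
        name=f"{node.name}_input_cast{input_index}",
    )
    graph.node.extend([cast_node])       # add the Cast to the graph
    node.input[input_index] = cast_output  # rewire the node to read the cast output
    return cast_node
```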
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index 0441ce494d560..361baaedc4a95 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -92,9 +92,6 @@ def create_attention_node(
q_matmul (NodeProto): MatMul node in fully connection for Q
k_matmul (NodeProto): MatMul node in fully connection for K
v_matmul (NodeProto): MatMul node in fully connection for V
- q_add (NodeProto): Add bias node in fully connection for Q
- k_add (NodeProto): Add bias node in fully connection for K
- v_add (NodeProto): Add bias node in fully connection for V
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
input (str): input name
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_vae.py b/onnxruntime/python/tools/transformers/fusion_attention_vae.py
new file mode 100644
index 0000000000000..e91a8a61fcc24
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/fusion_attention_vae.py
@@ -0,0 +1,304 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+from typing import Tuple, Union
+
+import numpy as np
+from fusion_base import Fusion
+from onnx import NodeProto, TensorProto, helper, numpy_helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionAttentionVae(Fusion):
+ """
+ Fuse Attention subgraph of Vae Decoder into one Attention node.
+ """
+
+ def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int):
+ super().__init__(model, "Attention", ["Softmax"])
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+
+ # Flags to show warning only once
+ self.num_heads_warning = True
+ self.hidden_size_warning = True
+
+ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> Tuple[int, int]:
+ """Detect num_heads and hidden_size from a reshape node.
+
+ Args:
+ reshape_q (NodeProto): reshape node for Q
+ add_q (NodeProto): add node for Q
+
+ Returns:
+ Tuple[int, int]: num_heads and hidden_size
+ """
+ concat = self.model.get_parent(reshape_q, 1)
+ if concat is None or len(concat.input) != 4:
+ return self.num_heads, self.hidden_size # Fall back to user specified value
+
+ value = self.model.get_constant_value(concat.input[2])
+ if not (value is not None and isinstance(value, np.ndarray) and value.size == 1):
+ return self.num_heads, self.hidden_size # Fall back to user specified value
+ num_heads = int(value)
+ if num_heads <= 0:
+ return self.num_heads, self.hidden_size # Fall back to user specified value
+
+ _, bias = self.model.get_constant_input(add_q)
+ if (bias is None) or (not isinstance(bias, np.ndarray)) or bias.ndim != 1:
+ return self.num_heads, self.hidden_size # Fall back to user specified value
+
+ hidden_size = bias.shape[0]
+
+ if self.num_heads > 0 and num_heads != self.num_heads:
+ if self.num_heads_warning:
+ logger.warning(
+ "Detected number of attention heads is %d. Ignore --num_heads %d", num_heads, self.num_heads
+ )
+ self.num_heads_warning = False # Do not show the warning more than once
+
+ if self.hidden_size > 0 and hidden_size != self.hidden_size:
+ if self.hidden_size_warning:
+ logger.warning("Detected hidden size is %d. Ignore --hidden_size %d", hidden_size, self.hidden_size)
+ self.hidden_size_warning = False # Do not show the warning more than once
+
+ return num_heads, hidden_size
+
+ def create_attention_node(
+ self,
+ q_matmul: NodeProto,
+ q_add: NodeProto,
+ k_matmul: NodeProto,
+ k_add: NodeProto,
+ v_matmul: NodeProto,
+ v_add: NodeProto,
+ num_heads: int,
+ hidden_size: int,
+ input_name: str,
+ output_name: str,
+ ) -> Union[NodeProto, None]:
+ """Create an Attention node.
+
+ Args:
+ q_matmul (NodeProto): MatMul node in fully connection for Q
+ q_add (NodeProto): Add bias node in fully connection for Q
+ k_matmul (NodeProto): MatMul node in fully connection for K
+ k_add (NodeProto): Add bias node in fully connection for K
+ v_matmul (NodeProto): MatMul node in fully connection for V
+ v_add (NodeProto): Add bias node in fully connection for V
+ num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+ hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
+ input_name (str): input name
+ output_name (str): output name
+
+ Returns:
+ Union[NodeProto, None]: the node created or None if failed.
+ """
+ if q_matmul.input[0] != input_name or k_matmul.input[0] != input_name or v_matmul.input[0] != input_name:
+ logger.debug(
+ "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
+ q_matmul.input[0],
+ k_matmul.input[0],
+ v_matmul.input[0],
+ )
+ return None
+
+ if hidden_size > 0 and (hidden_size % num_heads) != 0:
+ logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
+ return None
+
+ q_weight_tensor = self.model.get_initializer(q_matmul.input[1])
+ k_weight_tensor = self.model.get_initializer(k_matmul.input[1])
+ v_weight_tensor = self.model.get_initializer(v_matmul.input[1])
+ if not (q_weight_tensor and k_weight_tensor and v_weight_tensor):
+ return None
+
+ q_bias_tensor = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
+ k_bias_tensor = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
+ v_bias_tensor = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
+
+ q_bias = numpy_helper.to_array(q_bias_tensor)
+ k_bias = numpy_helper.to_array(k_bias_tensor)
+ v_bias = numpy_helper.to_array(v_bias_tensor)
+
+ q_bias_shape = np.prod(q_bias.shape)
+ k_bias_shape = np.prod(k_bias.shape)
+ v_bias_shape = np.prod(v_bias.shape)
+
+ # Sometimes weights are stored in fp16
+ if q_weight_tensor.data_type == 10:
+ logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
+ return None
+
+ q_weight = numpy_helper.to_array(q_weight_tensor)
+ k_weight = numpy_helper.to_array(k_weight_tensor)
+ v_weight = numpy_helper.to_array(v_weight_tensor)
+
+ # q, k and v weights are expected to have the same shape
+ if q_weight.shape != k_weight.shape or q_weight.shape != v_weight.shape:
+ return None
+
+ qw_in_size = q_weight.shape[0]
+ kw_in_size = k_weight.shape[0]
+ vw_in_size = v_weight.shape[0]
+
+ assert qw_in_size == kw_in_size and kw_in_size == vw_in_size
+
+ if hidden_size > 0 and hidden_size != qw_in_size:
+ raise ValueError(
+ f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
+ "Please provide a correct input hidden size or pass in 0"
+ )
+
+ # All the matrices can have the same shape, or the q and k matrices can have the same shape while v is different.
+ # For 2d weights, the shapes would be [in_size, out_size].
+ # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
+ qw_out_size = np.prod(q_weight.shape[1:])
+
+ qkv_weight = np.stack((q_weight, k_weight, v_weight), axis=1)
+ qkv_weight_dim = 3 * int(qw_out_size)
+
+ attention_node_name = self.model.create_node_name("Attention")
+
+ assert q_bias_shape == k_bias_shape == v_bias_shape
+
+ qkv_bias_dim = 0
+ qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0)
+ qkv_bias_dim = 3 * q_bias_shape
+
+ weight = helper.make_tensor(
+ name=attention_node_name + "_qkv_weight",
+ data_type=TensorProto.FLOAT,
+ dims=[qw_in_size, qkv_weight_dim],
+ vals=qkv_weight.flatten().tolist(),
+ )
+
+ self.model.add_initializer(weight, self.this_graph_name)
+
+ # No bias, use zeros
+ qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
+ qkv_bias_dim = 3 * hidden_size
+
+ bias = helper.make_tensor(
+ name=attention_node_name + "_qkv_bias",
+ data_type=TensorProto.FLOAT,
+ dims=[qkv_bias_dim],
+ vals=qkv_bias.flatten().tolist(),
+ )
+ self.model.add_initializer(bias, self.this_graph_name)
+
+ attention_inputs = [
+ input_name,
+ attention_node_name + "_qkv_weight",
+ attention_node_name + "_qkv_bias",
+ ]
+
+ attention_node = helper.make_node(
+ "Attention",
+ inputs=attention_inputs,
+ outputs=[output_name],
+ name=attention_node_name,
+ )
+ attention_node.domain = "com.microsoft"
+ attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+
+ self.increase_counter("Attention (self attention)")
+ return attention_node
+
+ def fuse(self, softmax_node, input_name_to_nodes, output_name_to_node):
+ matmul_qkv = self.model.find_first_child_by_type(softmax_node, "MatMul", input_name_to_nodes, recursive=False)
+ if matmul_qkv is None:
+ return
+
+ reshape_qkv = self.model.find_first_child_by_type(matmul_qkv, "Reshape", input_name_to_nodes, recursive=False)
+ if reshape_qkv is None:
+ return
+
+ transpose_qkv = self.model.find_first_child_by_type(
+ reshape_qkv, "Transpose", input_name_to_nodes, recursive=False
+ )
+ if transpose_qkv is None:
+ return
+
+ reshape_out = self.model.find_first_child_by_type(
+ transpose_qkv, "Reshape", input_name_to_nodes, recursive=False
+ )
+ if reshape_out is None:
+ return
+
+ matmul_out = self.model.find_first_child_by_type(reshape_out, "MatMul", input_name_to_nodes, recursive=False)
+ if matmul_out is None:
+ return
+
+ add_out = self.model.find_first_child_by_type(matmul_out, "Add", input_name_to_nodes, recursive=False)
+ if add_out is None:
+ return
+
+ transpose_out = self.model.find_first_child_by_type(add_out, "Transpose", input_name_to_nodes, recursive=False)
+ if transpose_out is None:
+ return
+
+ v_nodes = self.model.match_parent_path(
+ matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
+ )
+ if v_nodes is None:
+ logger.debug("fuse_attention: failed to match v path")
+ return
+ (_, _, _, add_v, matmul_v) = v_nodes
+
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
+ if qk_nodes is not None:
+ (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
+ else:
+ logger.debug("fuse_attention: failed to match qk path")
+ return
+
+ q_nodes = self.model.match_parent_path(
+ matmul_qk, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None]
+ )
+ if q_nodes is None:
+ logger.debug("fuse_attention: failed to match q path")
+ return
+ (_, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes
+ k_nodes = self.model.match_parent_path(
+ matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
+ )
+ if k_nodes is None:
+ logger.debug("fuse_attention: failed to match k path")
+ return
+ (_, _, _, _, add_k, matmul_k) = k_nodes
+
+ attention_last_node = reshape_out
+
+ q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, add_q)
+ if q_num_heads <= 0:
+ logger.debug("fuse_attention: failed to detect num_heads")
+ return
+
+ # The number of heads is the same for all paths, so q_num_heads is passed when creating the Attention node.
+ new_node = self.create_attention_node(
+ matmul_q,
+ add_q,
+ matmul_k,
+ add_k,
+ matmul_v,
+ add_v,
+ q_num_heads,
+ q_hidden_size,
+ matmul_q.input[0],
+ attention_last_node.output[0],
+ )
+ if new_node is None:
+ return
+
+ self.nodes_to_add.append(new_node)
+ self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+ self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
+
+ # Use prune graph to remove nodes since they are shared by all attention nodes.
+ self.prune_graph = True
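In get_num_heads_and_hidden_size above, num_heads is read from the third input of the Concat that feeds the Reshape of Q (the target shape is roughly (batch, seq_len, num_heads, head_size)), and hidden_size is the length of the 1-D bias of Q's Add node. A small numeric illustration with hypothetical values:

```python
import numpy as np

# Hypothetical values: the shape Concat feeding Reshape-of-Q carries a target
# shape like (batch, seq_len, num_heads, head_size); its third input is num_heads.
num_heads_constant = np.array([8])         # constant value of concat.input[2]
num_heads = int(num_heads_constant)        # -> 8, mirrors int(value) in the fusion

q_bias = np.zeros(512, dtype=np.float32)   # hypothetical 1-D bias of Q's Add node
hidden_size = q_bias.shape[0]              # -> 512

assert hidden_size % num_heads == 0        # head size = 512 // 8 = 64
print(num_heads, hidden_size)
```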
diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py
index f338c0a8c4438..43969aa04bf20 100644
--- a/onnxruntime/python/tools/transformers/fusion_base.py
+++ b/onnxruntime/python/tools/transformers/fusion_base.py
@@ -52,11 +52,11 @@ def apply(self):
if self.fused_count:
for key, value in self.fused_count.items():
if value:
- logger.info(f"Fused {key} count: {value}")
+ logger.info(f"Fused {key}: {value}")
else:
count = op_list.count(self.fused_op_type)
if count > 0:
- logger.info(f"Fused {self.description} count: {count}")
+ logger.info(f"Fused {self.description}: {count}")
self.model.remove_nodes(self.nodes_to_remove)
self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
diff --git a/onnxruntime/python/tools/transformers/fusion_bias_add.py b/onnxruntime/python/tools/transformers/fusion_bias_add.py
new file mode 100644
index 0000000000000..cdf54a3629726
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/fusion_bias_add.py
@@ -0,0 +1,58 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+from typing import Dict
+
+from fusion_base import Fusion
+from numpy import ndarray
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionBiasAdd(Fusion):
+ def __init__(self, model: OnnxModel):
+ super().__init__(model, "BiasAdd", "Add")
+
+ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
+ """
+ Fuse Add bias and Add skip connection into BiasAdd
+ """
+
+ nodes = self.model.match_parent_path(
+ add_node,
+ ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
+ [0, None, 0, 0, 0],
+ output_name_to_node,
+ )
+
+ if nodes is None:
+ return
+
+ bias_node = nodes[0]
+ skip_layer_norm = nodes[-1]
+
+ # Check skip connection is from SkipLayerNormalization output
+ if not (add_node.input[1] in skip_layer_norm.output):
+ return
+
+ bias_index, bias_value = self.model.get_constant_input(bias_node)
+ if not (isinstance(bias_index, int) and (bias_value is not None) and isinstance(bias_value, ndarray)):
+ return
+ if bias_value.ndim != 1:
+ return
+
+ self.nodes_to_remove.extend([add_node, bias_node])
+ node_name = self.model.create_node_name("BiasAdd")
+ fused_node = helper.make_node(
+ "BiasAdd",
+ inputs=[bias_node.input[1 - bias_index], bias_node.input[bias_index], add_node.input[1]],
+ outputs=[add_node.output[0]],
+ name=node_name,
+ )
+ fused_node.domain = "com.microsoft"
+ self.nodes_to_add.append(fused_node)
+ self.node_name_to_graph_name[node_name] = self.this_graph_name
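The fusion above collapses the trailing bias Add and residual Add (whose skip input comes from a SkipLayerNormalization output) into a single com.microsoft.BiasAdd node. A before/after sketch with hypothetical tensor names:

```python
from onnx import helper

# Before (as matched by FusionBiasAdd): bias Add followed by residual Add.
bias_add = helper.make_node("Add", ["matmul_out", "bias"], ["biased"], name="Add_bias")
skip_add = helper.make_node("Add", ["biased", "skip_layer_norm_out"], ["out"], name="Add_skip")

# After: one BiasAdd node taking (X, bias, skip), keeping the original output name.
fused = helper.make_node(
    "BiasAdd",
    inputs=["matmul_out", "bias", "skip_layer_norm_out"],
    outputs=["out"],
    name="BiasAdd_0",
    domain="com.microsoft",
)
```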
diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py
index a0a4d7c16de0b..c04c806eaaa86 100644
--- a/onnxruntime/python/tools/transformers/fusion_group_norm.py
+++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py
@@ -7,7 +7,6 @@
import numpy as np
from fusion_base import Fusion
-from fusion_utils import FusionUtils
from onnx import TensorProto, helper
from onnx_model import OnnxModel
@@ -143,35 +142,22 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
# instance_norm_scale might from Constant node. Use prune graph to clear it.
self.prune_graph = True
- # Right now GroupNorm only support float16 input. Need add a Cast in fp32 model.
- utils = FusionUtils(self.model)
-
- input = root
- output = last_node.output[0]
- if weight.dtype == np.float32:
- # Add a Cast node to get float16 input for GroupNorm
- cast_input, _cast_node = utils.cast_input(root, "float16")
- input = cast_input
-
- # Add a Cast node to convert back to float32 after GroupNorm
- output = group_norm_name + "_out"
- cast_node = helper.make_node("Cast", inputs=[group_norm_name + "_out"], outputs=[last_node.output[0]])
- cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.FLOAT))])
- self.model.add_node(cast_node)
+ input_name = root
+ output_name = last_node.output[0]
# NCHW to NHWC
transpose_input = helper.make_node(
"Transpose",
- [input],
- [input + "_NHWC"],
+ [input_name],
+ [input_name + "_NHWC"],
name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"),
perm=[0, 2, 3, 1],
)
new_node = helper.make_node(
"GroupNorm",
- inputs=[input + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"],
- outputs=[output + "_NHWC"],
+ inputs=[input_name + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"],
+ outputs=[output_name + "_NHWC"],
name=group_norm_name,
)
@@ -183,8 +169,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
# NHWC to NCHW
transpose_output = helper.make_node(
"Transpose",
- [output + "_NHWC"],
- [output],
+ [output_name + "_NHWC"],
+ [output_name],
name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"),
perm=[0, 3, 1, 2],
)
diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py
index cdfa2c626fc57..7629a6886ab94 100644
--- a/onnxruntime/python/tools/transformers/fusion_options.py
+++ b/onnxruntime/python/tools/transformers/fusion_options.py
@@ -51,9 +51,10 @@ def __init__(self, model_type):
)
# options for stable diffusion
- self.enable_group_norm = model_type == "unet"
- self.enable_bias_splitgelu = model_type == "unet"
- self.enable_packed_kv = model_type == "unet"
+ self.enable_group_norm = model_type in ["unet", "vae"]
+ self.enable_bias_splitgelu = model_type in ["unet"]
+ self.enable_packed_kv = model_type in ["unet"]
+ self.enable_bias_add = model_type in ["unet"]
def use_raw_attention_mask(self, use_raw_mask=True):
if use_raw_mask:
diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py
index d92ddd5f8e678..84961f799a122 100644
--- a/onnxruntime/python/tools/transformers/fusion_transpose.py
+++ b/onnxruntime/python/tools/transformers/fusion_transpose.py
@@ -25,6 +25,9 @@ def fuse(
output_name_to_node: Dict[str, NodeProto],
):
"""
Note that onnxruntime will do comprehensive transpose optimization after loading the model.
The purpose of this fusion is to make the graph clean before running onnxruntime.
+
Case 1:
(input)-->Transpose(perm=a)-->Transpose(perm=b)-->
After:
@@ -38,8 +41,6 @@ def fuse(
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
+----->Cast --> Transpose(perm=a*b)-->
-
-
"""
transpose_b = transpose_node
if transpose_b.input[0] not in output_name_to_node:
diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py
index 07fdf490337a4..0945be6cc6898 100644
--- a/onnxruntime/python/tools/transformers/fusion_utils.py
+++ b/onnxruntime/python/tools/transformers/fusion_utils.py
@@ -74,7 +74,23 @@ def remove_cast_int32(self, input_name: str):
self.model.replace_input_of_all_nodes(output_name, input_name)
@staticmethod
- def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes):
+ def update_node_input(node, i, new_input_name, input_name_to_nodes):
+ old_input_reference = 0
+ if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]:
+ input_name_to_nodes[node.input[i]].remove(node)
+ old_input_reference = len(input_name_to_nodes[node.input[i]])
+
+ node.input[i] = new_input_name
+
+ if new_input_name in input_name_to_nodes:
+ input_name_to_nodes[new_input_name].append(node)
+ else:
+ input_name_to_nodes[new_input_name] = [node]
+
+ return old_input_reference
+
+ @staticmethod
+ def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0):
"""
Before:
(input)-->parent-->node-->(output)
@@ -83,20 +99,16 @@ def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes):
|
+----->node-->(output)
- This function returns a flag about whether the parent node can be removed.
- Note that this function assumes the node has first input links from parent!
+ This function returns a flag indicating whether the parent node can be removed.
"""
- parent_can_be_removed = False
- input_name_to_nodes[node.input[0]].remove(node)
+
+ old_input_name = node.input[node_input_index]
+ new_input_name = parent_node.input[parent_input_index]
+ old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes)
+
# We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore.
- if len(input_name_to_nodes[node.input[0]]) == 0 and not model.find_graph_output(
- node.input[0]
- ): # checks main graph output. TODO: deal with subgraph
- parent_can_be_removed = True
- # self.nodes_to_remove.append(transpose_a)
-
- input_name_to_nodes[parent_node.input[0]].append(node)
- node.input[0] = parent_node.input[0]
+ parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name)
+
return parent_can_be_removed
@staticmethod
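The refactored skip_parent now delegates the rewiring to update_node_input and removes the parent only when no other consumer still reads the old input and that tensor is not a graph output. A toy illustration of that reference-count rule using plain Python data structures (not OnnxModel):

```python
# Toy bookkeeping mirroring input_name_to_nodes in skip_parent/update_node_input.
input_name_to_nodes = {"parent_out": ["node"], "parent_in": ["parent"]}
graph_outputs = set()  # names exposed as graph outputs

# Rewire: "node" now reads the parent's input instead of the parent's output.
input_name_to_nodes["parent_out"].remove("node")
remaining_consumers = len(input_name_to_nodes["parent_out"])  # old_input_reference
input_name_to_nodes.setdefault("parent_in", []).append("node")

# Parent can be removed only if nothing else consumes its output
# and that output is not a graph output.
parent_can_be_removed = remaining_consumers == 0 and "parent_out" not in graph_outputs
print(parent_can_be_removed)  # True in this toy case
```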
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
index 9a00dc8684f32..236923d5ff496 100755
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
@@ -14,10 +14,7 @@
}
-def get_test_settings():
- height = 512
- width = 512
- num_inference_steps = 50
+def example_prompts():
prompts = [
"a photo of an astronaut riding a horse on mars",
"cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
@@ -31,11 +28,11 @@ def get_test_settings():
"delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k",
]
- return height, width, num_inference_steps, prompts
+ return prompts
def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_safety_checker: bool):
- from diffusers import OnnxStableDiffusionPipeline
+ from diffusers import DPMSolverMultistepScheduler, OnnxStableDiffusionPipeline
import onnxruntime
@@ -54,6 +51,8 @@ def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_saf
provider=provider,
use_auth_token=True,
)
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe.set_progress_bar_config(disable=True)
if disable_safety_checker:
pipe.safety_checker = None
@@ -63,14 +62,15 @@ def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_saf
def get_torch_pipeline(model_name: str, disable_safety_checker: bool):
- from diffusers import StableDiffusionPipeline
+ from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
from torch import channels_last, float16
pipe = StableDiffusionPipeline.from_pretrained(
model_name, torch_dtype=float16, revision="fp16", use_auth_token=True
).to("cuda")
-
pipe.unet.to(memory_format=channels_last) # in-place operation
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe.set_progress_bar_config(disable=True)
if disable_safety_checker:
pipe.safety_checker = None
@@ -84,66 +84,111 @@ def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, dis
return f"{engine}_{short_model_name}_b{batch_size}" + ("" if disable_safety_checker else "_safe")
-def run_ort_pipeline(pipe, batch_size: int, image_filename_prefix: str):
+def run_ort_pipeline(pipe, batch_size: int, image_filename_prefix: str, height, width, steps, num_prompts, batch_count):
from diffusers import OnnxStableDiffusionPipeline
assert isinstance(pipe, OnnxStableDiffusionPipeline)
- height, width, num_inference_steps, prompts = get_test_settings()
+ prompts = example_prompts()
- pipe("warm up", height, width, num_inference_steps=2)
+ pipe("warm up", height, width, num_inference_steps=steps)
latency_list = []
for i, prompt in enumerate(prompts):
- input_prompts = [prompt] * batch_size
- inference_start = time.time()
- image = pipe(input_prompts, height, width, num_inference_steps).images[0]
- inference_end = time.time()
-
- latency = inference_end - inference_start
- latency_list.append(latency)
- print(f"Inference took {latency} seconds")
- image.save(f"{image_filename_prefix}_{i}.jpg")
+ if i >= num_prompts:
+ break
+ for j in range(batch_count):
+ inference_start = time.time()
+ images = pipe(
+ prompt,
+ height,
+ width,
+ num_inference_steps=steps,
+ negative_prompt=None,
+ guidance_scale=7.5,
+ num_images_per_prompt=batch_size,
+ ).images
+ inference_end = time.time()
+ latency = inference_end - inference_start
+ latency_list.append(latency)
+ print(f"Inference took {latency} seconds")
+ for k, image in enumerate(images):
+ image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg")
+
print("Average latency in seconds:", sum(latency_list) / len(latency_list))
-def run_torch_pipeline(pipe, batch_size: int, image_filename_prefix: str):
+def run_torch_pipeline(
+ pipe, batch_size: int, image_filename_prefix: str, height, width, steps, num_prompts, batch_count
+):
import torch
- height, width, num_inference_steps, prompts = get_test_settings()
+ prompts = example_prompts()
- pipe("warm up", height, width, num_inference_steps=2)
+ pipe("warm up", height, width, num_inference_steps=steps)
torch.set_grad_enabled(False)
latency_list = []
for i, prompt in enumerate(prompts):
- input_prompts = [prompt] * batch_size
- torch.cuda.synchronize()
- inference_start = time.time()
- image = pipe(input_prompts, height, width, num_inference_steps).images[0]
+ if i >= num_prompts:
+ break
torch.cuda.synchronize()
- inference_end = time.time()
-
- latency = inference_end - inference_start
- latency_list.append(latency)
- print(f"Inference took {latency} seconds")
- image.save(f"{image_filename_prefix}_{i}.jpg")
+ for j in range(batch_count):
+ inference_start = time.time()
+ images = pipe(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=steps,
+ guidance_scale=7.5,
+ negative_prompt=None,
+ num_images_per_prompt=batch_size,
+ generator=None, # torch.Generator
+ ).images
+
+ torch.cuda.synchronize()
+ inference_end = time.time()
+ latency = inference_end - inference_start
+ latency_list.append(latency)
+ print(f"Inference took {latency} seconds")
+ for k, image in enumerate(images):
+ image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg")
print("Average latency in seconds:", sum(latency_list) / len(latency_list))
-def run_ort(model_name: str, directory: str, provider: str, batch_size: int, disable_safety_checker: bool):
+def run_ort(
+ model_name: str,
+ directory: str,
+ provider: str,
+ batch_size: int,
+ disable_safety_checker: bool,
+ height,
+ width,
+ steps,
+ num_prompts,
+ batch_count,
+):
load_start = time.time()
pipe = get_ort_pipeline(model_name, directory, provider, disable_safety_checker)
load_end = time.time()
print(f"Model loading took {load_end - load_start} seconds")
image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker)
- run_ort_pipeline(pipe, batch_size, image_filename_prefix)
-
-
-def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool):
+ run_ort_pipeline(pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count)
+
+
+def run_torch(
+ model_name: str,
+ batch_size: int,
+ disable_safety_checker: bool,
+ height,
+ width,
+ steps,
+ num_prompts,
+ batch_count,
+):
import torch
torch.backends.cudnn.enabled = True
@@ -159,7 +204,7 @@ def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool):
image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker)
with torch.inference_mode():
- run_torch_pipeline(pipe, batch_size, image_filename_prefix)
+ run_torch_pipeline(pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count)
def parse_arguments():
@@ -201,7 +246,58 @@ def parse_arguments():
)
parser.set_defaults(enable_safety_checker=False)
- parser.add_argument("-b", "--batch_size", type=int, default=1)
+ parser.add_argument(
+ "-b",
+ "--batch_size",
+ type=int,
+ default=1,
+ choices=[1, 2, 4, 8, 16, 32],
+ help="Number of images per batch",
+ )
+
+ parser.add_argument(
+ "--height",
+ required=False,
+ type=int,
+ default=512,
+ help="Output image height",
+ )
+
+ parser.add_argument(
+ "--width",
+ required=False,
+ type=int,
+ default=512,
+ help="Output image width",
+ )
+
+ parser.add_argument(
+ "-s",
+ "--steps",
+ required=False,
+ type=int,
+ default=50,
+ help="Number of steps",
+ )
+
+ parser.add_argument(
+ "-n",
+ "--num_prompts",
+ required=False,
+ type=int,
+ default=1,
+ help="Number of prompts",
+ )
+
+ parser.add_argument(
+ "-c",
+ "--batch_count",
+ required=False,
+ type=int,
+ choices=range(1, 11),
+ default=10,
+ help="Number of batches to test",
+ )
args = parser.parse_args()
return args
@@ -219,13 +315,33 @@ def main():
# Need remove a line https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L307
# in diffuers to run batch_size > 1.
assert (
- args.enable_safety_checker
+ not args.enable_safety_checker
), "batch_size > 1 is not compatible with safety checker due to a bug in diffuers"
provider = "CUDAExecutionProvider" # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffuers
- run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker)
+ run_ort(
+ sd_model,
+ args.pipeline,
+ provider,
+ args.batch_size,
+ not args.enable_safety_checker,
+ args.height,
+ args.width,
+ args.steps,
+ args.num_prompts,
+ args.batch_count,
+ )
else:
- run_torch(sd_model, args.batch_size, not args.enable_safety_checker)
+ run_torch(
+ sd_model,
+ args.batch_size,
+ not args.enable_safety_checker,
+ args.height,
+ args.width,
+ args.steps,
+ args.num_prompts,
+ args.batch_count,
+ )
if __name__ == "__main__":
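The new height, width, steps, num_prompts, and batch_count arguments are threaded through to run_ort_pipeline and run_torch_pipeline, whose bodies are outside this hunk. Below is a minimal sketch of how such a loop would typically consume them, assuming the diffusers pipeline call signature (prompt, height, width, num_inference_steps); the helper name and prompt text are hypothetical.

```python
import time


def run_pipeline_sketch(pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count):
    # Placeholder prompts; the real script builds its own prompt list.
    prompts = ["a photo of an astronaut riding a horse on mars"] * num_prompts
    for i, prompt in enumerate(prompts):
        for batch in range(batch_count):
            start = time.time()
            # Each call generates batch_size images of the requested resolution.
            images = pipe(
                prompt=[prompt] * batch_size,
                height=height,
                width=width,
                num_inference_steps=steps,
            ).images
            print(f"prompt {i}, batch {batch}: {time.time() - start:.1f}s for {len(images)} images")
            for k, image in enumerate(images):
                image.save(f"{image_filename_prefix}_{i}_{batch}_{k}.jpg")
```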
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
index 932be4a19ae6b..31b2a22c2f615 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -10,35 +10,54 @@
# pip install -r requirements.txt
# huggingface-cli login
# wget https://raw.githubusercontent.com/huggingface/diffusers/v0.12.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
-# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32
+# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/sd-v1-5
+# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/sd-v2-1
# Note that this script might not be compatible with older or newer version of diffusers.
# Then you can use this script to convert them to float16 like the following:
-# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16
+# python optimize_pipeline.py -i $ONNX_ROOT/sd-v1-5 -o $ONNX_ROOT/sd-v1-5-fp16 --float16
# Or
-# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16
+# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/sd-v1-5 -o $ONNX_ROOT/sd-v1-5-fp16 --float16
#
# Note that output model is for CUDA Execution Provider. It might not run in CPU Execution Provider.
-# Stable diffusion 2.1 model will get black images using float16 Attention. It is a known issue that we are working on.
+#
+# The Stable Diffusion 2.1 model produces black images when using float16 Attention. A workaround is to force Attention to run in float32:
+# python optimize_pipeline.py -i $ONNX_ROOT/sd-v2-1 -o $ONNX_ROOT/sd-v2-1-fp16 --float16 --force_fp32_ops unet:Attention
+
import argparse
import logging
import os
import shutil
import sys
+import tempfile
from pathlib import Path
+from typing import List
import coloredlogs
+import onnx
+from packaging import version
+
+import onnxruntime
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from fusion_options import FusionOptions
-from optimizer import optimize_model # noqa: E402
+from onnx_model_clip import ClipOnnxModel
+from onnx_model_unet import UnetOnnxModel
+from onnx_model_vae import VaeOnnxModel
+from optimizer import optimize_by_onnxruntime, optimize_model
logger = logging.getLogger(__name__)
def optimize_sd_pipeline(
- source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool
+ source_dir: Path,
+ target_dir: Path,
+ overwrite: bool,
+ use_external_data_format: bool,
+ float16: bool,
+ force_fp32_ops: List[str],
+ enable_runtime_optimization: bool,
):
"""Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16.
@@ -48,54 +67,115 @@ def optimize_sd_pipeline(
overwrite (bool): Overwrite files if exists.
use_external_data_format (bool): save onnx model to two files: one for onnx graph, another for weights
float16 (bool): use half precision
+        force_fp32_ops (List[str]): Operators that are forced to run in float32.
+        enable_runtime_optimization (bool): Run graph optimizations with Onnx Runtime.
Raises:
RuntimeError: input onnx model does not exist
RuntimeError: output onnx model path existed
"""
- dirs_with_onnx = ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"]
- for name in dirs_with_onnx:
+ model_type_mapping = {
+ "unet": "unet",
+ "vae_encoder": "vae",
+ "vae_decoder": "vae",
+ "text_encoder": "clip",
+ "safety_checker": "unet",
+ }
+
+ model_type_class_mapping = {
+ "unet": UnetOnnxModel,
+ "vae": VaeOnnxModel,
+ "clip": ClipOnnxModel,
+ }
+
+ force_fp32_operators = {
+ "unet": [],
+ "vae_encoder": [],
+ "vae_decoder": [],
+ "text_encoder": [],
+ "safety_checker": [],
+ }
+
+ if force_fp32_ops:
+ for fp32_operator in force_fp32_ops:
+ parts = fp32_operator.split(":")
+ if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
+ force_fp32_operators[parts[0]].append(parts[1])
+ else:
+ raise ValueError(
+                    f"--force_fp32_ops must be in the format module:Operator (like unet:Attention), got {fp32_operator}"
+ )
+
+ for name, model_type in model_type_mapping.items():
onnx_model_path = source_dir / name / "model.onnx"
if not os.path.exists(onnx_model_path):
message = f"input onnx model does not exist: {onnx_model_path}."
- if name not in ["safety_checker", "feature_extractor"]:
+ if name not in ["safety_checker"]:
raise RuntimeError(message)
continue
+ # Prepare output directory
+ optimized_model_path = target_dir / name / "model.onnx"
+ output_dir = optimized_model_path.parent
+ if optimized_model_path.exists():
+ if not overwrite:
+ raise RuntimeError(f"output onnx model path existed: {optimized_model_path}")
+
+ if output_dir.exists():
+ shutil.rmtree(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
# Graph fusion before fp16 conversion, otherwise they cannot be fused later.
# Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
- logger.info(f"optimize {onnx_model_path}...")
+ logger.info(f"Optimize {onnx_model_path}...")
+
+        # There are some optimizations that are not available in v1.14 or older versions.
+ has_all_optimizations = version.parse(onnxruntime.__version__) > version.parse("1.14.0")
- fusion_options = FusionOptions("unet")
+ fusion_options = FusionOptions(model_type)
fusion_options.enable_packed_kv = float16
+ fusion_options.enable_bias_add = has_all_optimizations
m = optimize_model(
str(onnx_model_path),
- model_type="unet",
+ model_type=model_type,
num_heads=0, # will be deduced from graph
hidden_size=0, # will be deduced from graph
opt_level=0,
optimization_options=fusion_options,
- use_gpu=False,
+ use_gpu=True,
)
if float16:
- logger.info("convert %s to float16 ...", name)
- m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize", "GroupNorm"])
-
- optimized_model_path = target_dir / name / "model.onnx"
- output_dir = optimized_model_path.parent
- if optimized_model_path.exists():
- if not overwrite:
- raise RuntimeError(f"output onnx model path existed: {optimized_model_path}")
-
- if output_dir.exists():
- shutil.rmtree(output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
+ logger.info("Convert %s to float16 ...", name)
+ op_block_list = ["RandomNormalLike"]
+ m.convert_float_to_float16(
+ keep_io_types=False,
+ op_block_list=op_block_list + force_fp32_operators[name],
+ )
+
+ if enable_runtime_optimization and (float16 or (name not in ["unet"])):
+            # Use this step to see the final graph that is executed by Onnx Runtime.
+            # Note that ORT cannot save models larger than 2GB, so we exclude the unet float32 model.
+ # This step is optional since it has no impact on performance except model loading time.
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Save to a temporary file so that we can load it with Onnx Runtime.
+ logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
+ tmp_model_path = Path(tmp_dir) / "model.onnx"
+ m.save_model_to_file(str(tmp_model_path))
+ ort_optimized_model_path = tmp_model_path
+ optimize_by_onnxruntime(
+ str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path)
+ )
+ model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
+ m = model_type_class_mapping[model_type](model)
+
+ m.get_operator_statistics()
+ m.get_fused_operator_statistics()
m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
- logger.info("%s => %s", onnx_model_path, optimized_model_path)
+ logger.info("%s is optimized", name)
+ logger.info("*" * 20)
def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
@@ -117,7 +197,7 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
if not os.path.exists(source_path):
message = f"source path does not exist: {source_path}"
- if name not in ["safety_checker", "feature_extractor"]:
+ if name not in ["feature_extractor"]:
raise RuntimeError(message)
continue
@@ -177,6 +257,22 @@ def parse_arguments():
)
parser.set_defaults(float16=False)
+ parser.add_argument(
+ "--force_fp32_ops",
+ required=False,
+ nargs="+",
+ type=str,
+        help="Force the given operators (like unet:Attention) to run in float32. Operator names are case-sensitive.",
+ )
+
+ parser.add_argument(
+ "--inspect",
+ required=False,
+ action="store_true",
+        help="Inspect the optimized graph from Onnx Runtime for debugging purposes. This option has no impact on model performance.",
+ )
+ parser.set_defaults(inspect=False)
+
parser.add_argument(
"--overwrite",
required=False,
@@ -202,9 +298,16 @@ def parse_arguments():
def main():
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
args = parse_arguments()
+ logger.info("Arguments: %s", str(args))
copy_extra_directory(Path(args.input), Path(args.output), args.overwrite)
optimize_sd_pipeline(
- Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format, args.float16
+ Path(args.input),
+ Path(args.output),
+ args.overwrite,
+ args.use_external_data_format,
+ args.float16,
+ args.force_fp32_ops,
+ args.inspect,
)
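The --force_fp32_ops flag introduced above expects module:Operator entries such as unet:Attention. For illustration only, here is a standalone sketch of the same validation logic; the helper name is hypothetical and the module list mirrors optimize_sd_pipeline.

```python
from typing import Dict, List, Optional


def parse_force_fp32_ops(items: Optional[List[str]]) -> Dict[str, List[str]]:
    # Submodel names accepted on the left of the colon.
    modules = ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"]
    result: Dict[str, List[str]] = {name: [] for name in modules}
    for item in items or []:
        parts = item.split(":")
        # Operator names are case-sensitive and must start with an upper-case letter.
        if len(parts) == 2 and parts[0] in result and parts[1] and parts[1][0].isupper():
            result[parts[0]].append(parts[1])
        else:
            raise ValueError(f"expected module:Operator like unet:Attention, got {item}")
    return result


print(parse_force_fp32_ops(["unet:Attention"]))  # {'unet': ['Attention'], 'vae_encoder': [], ...}
```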
diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py
index 276a9428ecf72..c8380c93f591c 100644
--- a/onnxruntime/python/tools/transformers/onnx_model.py
+++ b/onnxruntime/python/tools/transformers/onnx_model.py
@@ -802,11 +802,14 @@ def prune_graph(self, outputs=None):
self.model.graph.input.remove(input)
if input_to_remove or output_to_remove or nodes_to_remove:
- logger.info(
- "Graph pruned: {} inputs, {} outputs and {} nodes are removed".format(
- len(input_to_remove), len(output_to_remove), len(nodes_to_remove)
- )
- )
+ removed = []
+ if input_to_remove:
+ removed.append(f"{len(input_to_remove)} inputs")
+ if output_to_remove:
+ removed.append(f"{len(output_to_remove)} outputs")
+ if nodes_to_remove:
+ removed.append(f"{len(nodes_to_remove)} nodes")
+ logger.info("Removed %s", ", ".join(removed))
self.update_graph()
@@ -1022,6 +1025,18 @@ def get_opset_version(self):
return opset.version
raise RuntimeError("ONNX model has no opset for default domain")
+ def get_operator_statistics(self, include_domain=False):
+ """
+ Returns node count of operators.
+ """
+ op_count = {}
+ for node in self.nodes():
+ op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type
+ op_count[op] = 1 if op not in op_count else (op_count[op] + 1)
+
+        logger.info(f"Operators: {op_count}")
+ return op_count
+
@staticmethod
def has_same_value(tensor1: TensorProto, tensor2: TensorProto) -> bool:
"""Returns True when two tensors have same value.
diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py
new file mode 100644
index 0000000000000..93e8623768067
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py
@@ -0,0 +1,33 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+from onnx import ModelProto
+from onnx_model_unet import UnetOnnxModel
+
+logger = getLogger(__name__)
+
+
+class ClipOnnxModel(UnetOnnxModel):
+ def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
+ super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
+
+ def get_fused_operator_statistics(self):
+ """
+ Returns node count of fused operators.
+ """
+ op_count = {}
+ ops = [
+ "Attention",
+ "LayerNormalization",
+ "SkipLayerNormalization",
+ ]
+ for op in ops:
+ nodes = self.get_nodes_by_op_type(op)
+ op_count[op] = len(nodes)
+
+        logger.info(f"Optimized operators: {op_count}")
+ return op_count
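get_operator_statistics (added to OnnxModel above) and get_fused_operator_statistics simply count nodes by op_type. A minimal standalone sketch of the same counting over a plain onnx model follows; the file path is a placeholder, and this version only walks the main graph.

```python
import onnx


def operator_statistics(model_path: str, include_domain: bool = False) -> dict:
    model = onnx.load(model_path)
    op_count: dict = {}
    # Count each op_type, optionally prefixed by its domain (e.g. com.microsoft:GroupNorm).
    for node in model.graph.node:
        op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type
        op_count[op] = op_count.get(op, 0) + 1
    return op_count


# Example (placeholder path): operator_statistics("sd-v1-5-fp16/unet/model.onnx", include_domain=True)
```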
diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py
index 32a98149825c3..b460d7d8a7f8f 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_unet.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py
@@ -7,6 +7,7 @@
from typing import Optional
from fusion_attention_unet import FusionAttentionUnet
+from fusion_bias_add import FusionBiasAdd
from fusion_biassplitgelu import FusionBiasSplitGelu
from fusion_group_norm import FusionGroupNorm
from fusion_nhwc_conv import FusionNhwcConv
@@ -36,7 +37,6 @@ def preprocess(self):
self.remove_useless_div()
def postprocess(self):
- self.merge_sequential_transpose()
self.prune_graph()
self.remove_unused_constant()
@@ -54,14 +54,14 @@ def remove_useless_div(self):
if nodes_to_remove:
self.remove_nodes(nodes_to_remove)
- logger.info("Removed %d useless Div (by 1) nodes", len(nodes_to_remove))
+ logger.info("Removed %d Div nodes", len(nodes_to_remove))
def convert_conv_to_nhwc(self):
# Do not update weight here since save external data has a bug
conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False)
conv_to_nhwc_conv.apply()
- def merge_sequential_transpose(self):
+ def merge_adjacent_transpose(self):
fusion_transpose = FusionTranspose(self)
fusion_transpose.apply()
@@ -89,6 +89,20 @@ def merge_sequential_transpose(self):
if total:
logger.info("Removed %d Transpose nodes", total)
+ def fuse_attention(self, options: Optional[FusionOptions] = None):
+ # Self Attention
+ self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False)
+ self_attention_fusion.apply()
+
+ # Cross Attention
+ enable_packed_kv = (options is None) or options.enable_packed_kv
+ cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv)
+ cross_attention_fusion.apply()
+
+ def fuse_bias_add(self):
+ fusion = FusionBiasAdd(self)
+ fusion.apply()
+
def optimize(self, options: Optional[FusionOptions] = None):
if (options is not None) and not options.enable_shape_inference:
self.disable_shape_inference()
@@ -117,12 +131,7 @@ def optimize(self, options: Optional[FusionOptions] = None):
bias_split_gelu_fusion.apply()
if (options is None) or options.enable_attention:
- self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False)
- self_attention_fusion.apply()
-
- enable_packed_kv = (options is None) or options.enable_packed_kv
- cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv)
- cross_attention_fusion.apply()
+ self.fuse_attention()
if (options is None) or options.enable_skip_layer_norm:
self.fuse_skip_layer_norm()
@@ -141,6 +150,10 @@ def optimize(self, options: Optional[FusionOptions] = None):
if options is not None and options.enable_gelu_approximation:
self.gelu_approximation()
+ self.merge_adjacent_transpose()
+
+ self.fuse_bias_add()
+
self.postprocess()
logger.info(f"opset version: {self.get_opset_version()}")
@@ -153,8 +166,6 @@ def get_fused_operator_statistics(self):
ops = [
"Attention",
"MultiHeadAttention",
- "Gelu",
- "FastGelu",
"LayerNormalization",
"SkipLayerNormalization",
"BiasSplitGelu",
diff --git a/onnxruntime/python/tools/transformers/onnx_model_vae.py b/onnxruntime/python/tools/transformers/onnx_model_vae.py
new file mode 100644
index 0000000000000..47d3f9ddfe6fd
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/onnx_model_vae.py
@@ -0,0 +1,42 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+from typing import Optional
+
+from fusion_attention_vae import FusionAttentionVae
+from fusion_options import FusionOptions
+from onnx import ModelProto
+from onnx_model_unet import UnetOnnxModel
+
+logger = getLogger(__name__)
+
+
+class VaeOnnxModel(UnetOnnxModel):
+ def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
+ assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
+ super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
+
+ def fuse_attention(self, options: Optional[FusionOptions] = None):
+ # Self Attention
+ self_attention_fusion = FusionAttentionVae(self, self.hidden_size, self.num_heads)
+ self_attention_fusion.apply()
+
+ def get_fused_operator_statistics(self):
+ """
+ Returns node count of fused operators.
+ """
+ op_count = {}
+ ops = [
+ "Attention",
+ "GroupNorm",
+ "NhwcConv",
+ ]
+ for op in ops:
+ nodes = self.get_nodes_by_op_type(op)
+ op_count[op] = len(nodes)
+
+        logger.info(f"Optimized operators: {op_count}")
+ return op_count
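A hedged usage sketch of the new wrapper classes: load each optimized submodel and print its fused-operator counts. The file paths are placeholders, and the imports assume the transformers tool directory is on sys.path, mirroring model_type_class_mapping in optimize_pipeline.py.

```python
import onnx
from onnx_model_clip import ClipOnnxModel
from onnx_model_unet import UnetOnnxModel
from onnx_model_vae import VaeOnnxModel

# Placeholder paths for an exported stable diffusion pipeline directory.
wrappers = {
    "unet/model.onnx": UnetOnnxModel,
    "vae_decoder/model.onnx": VaeOnnxModel,
    "text_encoder/model.onnx": ClipOnnxModel,
}

for path, wrapper_class in wrappers.items():
    model = wrapper_class(onnx.load(path))  # num_heads/hidden_size default to 0 (deduced from graph)
    print(path, model.get_fused_operator_statistics())
```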
diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py
index a18535c10591d..65de1bf770992 100644
--- a/onnxruntime/python/tools/transformers/optimizer.py
+++ b/onnxruntime/python/tools/transformers/optimizer.py
@@ -29,10 +29,12 @@
from onnx_model_bert import BertOnnxModel
from onnx_model_bert_keras import BertOnnxModelKeras
from onnx_model_bert_tf import BertOnnxModelTF
+from onnx_model_clip import ClipOnnxModel
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
from onnx_model_unet import UnetOnnxModel
+from onnx_model_vae import VaeOnnxModel
logger = logging.getLogger(__name__)
@@ -49,8 +51,11 @@
0,
), # might add a class for GPT2OnnxModel for TF later.
"tnlr": (TnlrOnnxModel, "pytorch", 1),
- "unet": (UnetOnnxModel, "pytorch", 1),
"t5": (T5OnnxModel, "pytorch", 2),
+ # Stable Diffusion models
+ "unet": (UnetOnnxModel, "pytorch", 1),
+ "vae": (VaeOnnxModel, "pytorch", 1),
+ "clip": (ClipOnnxModel, "pytorch", 1),
}
@@ -152,7 +157,7 @@ def optimize_by_fusion(
Returns:
object of an optimizer class.
"""
- if model_type not in ["bert", "unet"] and (num_heads == 0 or hidden_size == 0):
+ if model_type not in ["bert", "unet", "vae", "clip"] and (num_heads == 0 or hidden_size == 0):
logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}")
(optimizer_class, producer, _) = MODEL_TYPES[model_type]
@@ -446,10 +451,11 @@ def main():
optimizer.save_model_to_file(args.output, args.use_external_data_format)
- if optimizer.is_fully_optimized():
- logger.info("The model has been fully optimized.")
- else:
- logger.info("The model has been optimized.")
+ if args.model_type in ["bert", "gpt2"]:
+ if optimizer.is_fully_optimized():
+ logger.info("The model has been fully optimized.")
+ else:
+ logger.info("The model has been optimized.")
if __name__ == "__main__":
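With the vae and clip model types registered in MODEL_TYPES, optimize_model can be called on each stable diffusion submodel directly. A sketch with placeholder paths, mirroring the arguments used in optimize_pipeline.py, is shown below.

```python
from fusion_options import FusionOptions
from optimizer import optimize_model

# Placeholder submodel paths in an exported stable diffusion pipeline directory.
for name, model_type in [("unet", "unet"), ("vae_decoder", "vae"), ("text_encoder", "clip")]:
    options = FusionOptions(model_type)
    m = optimize_model(
        f"sd-v1-5/{name}/model.onnx",
        model_type=model_type,
        num_heads=0,     # deduced from the graph
        hidden_size=0,   # deduced from the graph
        opt_level=0,
        optimization_options=options,
        use_gpu=True,
    )
    m.get_fused_operator_statistics()
```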
diff --git a/onnxruntime/test/contrib_ops/bias_add_op_test.cc b/onnxruntime/test/contrib_ops/bias_add_op_test.cc
new file mode 100644
index 0000000000000..733ef81999a10
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/bias_add_op_test.cc
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <memory>
+#include <vector>
+#include "gtest/gtest.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+using namespace onnxruntime::test;
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(USE_CUDA) // The operator has only CUDA implementation right now
+static std::vector<float> GetExpectedResult(const std::vector<float>& input_data,
+ const std::vector