diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f01a7ab14a61e..f0543f2649205 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -6,6 +6,7 @@ Do not modify directly.*
 * com.microsoft.Attention
 * com.microsoft.AttnLSTM
 * com.microsoft.BeamSearch
+ * com.microsoft.BiasAdd
 * com.microsoft.BiasDropout
 * com.microsoft.BiasGelu
 * com.microsoft.BiasSoftmax
@@ -468,6 +469,40 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.BiasAdd**
+
+  Add input with bias, then add residual inputs.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Inputs
+
+<dl>
+<dt><tt>X</tt> : T</dt>
+<dd>Input tensor. Dimensions are (N, S, C), where N is the batch size, S is image size H*W, and C is number of channels</dd>
+<dt><tt>bias</tt> : T</dt>
+<dd>Bias tensor. Dimensions are (C)</dd>
+<dt><tt>skip</tt> : T</dt>
+<dd>Residual tensor. Dimensions are (N, S, C)</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>Y</tt> : T</dt>
+<dd>The output tensor with dimensions (N, S, C)</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
+</dl>
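For reference (this note and sketch are editorial additions, not part of the diff), the schema above amounts to an elementwise sum with the bias broadcast over the channel axis. A minimal NumPy sketch of the intended semantics, assuming the (N, S, C) and (C) shapes listed above:

```python
# Illustrative reference only: com.microsoft.BiasAdd computes Y = X + bias + skip,
# with bias broadcast over the last (channel) axis.
import numpy as np


def bias_add_reference(x: np.ndarray, bias: np.ndarray, skip: np.ndarray) -> np.ndarray:
    assert x.shape == skip.shape and bias.shape == (x.shape[-1],)
    return x + bias + skip


# Example shapes matching the schema: X and skip are (N, S, C), bias is (C).
n, s, c = 2, 64 * 64, 320  # the CUDA kernel in this change only accepts C in {320, 640, 1280}
y = bias_add_reference(
    np.random.randn(n, s, c).astype(np.float32),
    np.random.randn(c).astype(np.float32),
    np.random.randn(n, s, c).astype(np.float32),
)
```

The channel restriction comes from the CUDA kernel added later in this change, which launches one 320-thread block per (N, S) position.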
+
+
 ### **com.microsoft.BiasDropout**

   output, dropout_mask = Dropout(data + bias, ratio) + residual,
   Intended to specialize the dropout pattern commonly found in transformer models.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 00b71d2946215..08178f206568e 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -787,6 +787,7 @@ Do not modify directly.*
 |**Operator Domain:** *com.microsoft*||||
 |Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T2** = tensor(bool)|
 |BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |BiasSoftmax|*in* data:**T**<br> *in* bias:**T**<br>
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index cabf890ab341d..1cefd44844f39 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -21,6 +21,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasAdd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); @@ -145,107 +147,109 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // These ops were experimental ops in onnx domain which have been removed now. We add them here as - // contrib ops to maintain backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as + // contrib ops to maintain backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc new file mode 100644 index 0000000000000..5d5183221eda4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/bias_add.h" +#include "contrib_ops/cuda/diffusion/bias_add_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BiasAdd, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BiasAdd); + +REGISTER_KERNEL_TYPED(MLFloat16); +REGISTER_KERNEL_TYPED(float); + +using namespace ONNX_NAMESPACE; + +template +BiasAdd::BiasAdd(const OpKernelInfo& op_info) : CudaKernel(op_info) { +} + +template +Status BiasAdd::ComputeInternal(OpKernelContext* context) const { + // Input: [batch_size, height*width, channels] + // Bias: [channels] + // Skip: [batch_size, height*width, channels] + // Output: [batch_size, height*width, channels] + + const Tensor* input = context->Input(0); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The input is expected to have 3 dimensions, got ", input_dims.size()); + } + + if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels should be 320, 640 or 1280, got ", input_dims[2]); + } + + const Tensor* bias = context->Input(1); + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The bias is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in the last dimension of input and bias are not the same"); + } + + const Tensor* skip = context->Input(2); + if (skip->Shape() != input->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Shape of input and skip (residual) shall be the same"); + } + + Tensor* output = context->Output(0, input->Shape()); + + typedef typename ToCudaType::MappedType CudaT; + const int32_t grid_size = static_cast(input_dims[0] * input_dims[1]); + LaunchBiasAddKernel(Stream(context), grid_size, static_cast(input_dims[2]), + reinterpret_cast(input->Data()), + reinterpret_cast(bias->Data()), + reinterpret_cast(skip->Data()), + reinterpret_cast(output->MutableData())); + + CUDA_RETURN_IF_ERROR(cudaPeekAtLastError()); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h new file mode 100644 index 0000000000000..6f4904f4c8de9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class BiasAdd final : public CudaKernel { + public: + BiasAdd(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu new file mode 100644 index 0000000000000..2983cc99e30b1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The CUDA kernel is modified from SeqLen2Spatial plugin of TensorRT 8.5. +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/diffusion/bias_add_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, T* output) { + int32_t base_offset = blockIdx.x * C + threadIdx.x; + int32_t bias_offset = threadIdx.x; + +#pragma unroll + for (int32_t i = 0; i < C / TPB; ++i) { + output[base_offset] = input[base_offset] + bias[bias_offset] + residual[base_offset]; + base_offset += TPB; + bias_offset += TPB; + } +} + +template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); +template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); +template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); +template __global__ void BiasAddKernel(half const*, half const*, half const*, half*); +template __global__ void BiasAddKernel(half const*, half const*, half const*, half*); +template __global__ void BiasAddKernel(half const*, half const*, half const*, half*); + +template +void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, + T const* input, T const* bias, T const* residual, T* output) { + constexpr int32_t TPB = 320; // thread per block + switch (num_channels) { + case 320: + (BiasAddKernel)<<>>(input, bias, residual, output); + break; + case 640: + (BiasAddKernel)<<>>(input, bias, residual, output); + break; + case 1280: + (BiasAddKernel)<<>>(input, bias, residual, output); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, + float const* input, float const* bias, float const* residual, float* output); + +template void LaunchBiasAddKernel(cudaStream_t stream, 
int32_t grid_size, int32_t num_channels, + half const* input, half const* bias, half const* residual, half* output); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h new file mode 100644 index 0000000000000..d3397ea035959 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/common/status.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, + T const* input, T const* bias, T const* residual, T* output); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu index 3cb95dad26b36..8069cbc0a1e0e 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -23,6 +23,8 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { @@ -35,13 +37,9 @@ __global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) { #pragma unroll for (int32_t i = 0; i < HHS / TPB; ++i) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) auto value_left = (float)(input[index_input] + bias[index_bias]); auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]); -#else - auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]); - auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]); -#endif + // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0) float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f); float result = value_left * gelu_right; diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index 14a267357371d..c6d3db7fbe6da 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -111,5 +111,32 @@ ONNX_MS_OPERATOR_SET_SCHEMA( updateOutputShape(ctx, 0, output_shape); } })); + +constexpr const char* BiasAdd_ver1_doc = R"DOC( +Add input with bias, then add residual inputs. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + BiasAdd, 1, + OpSchema() + .SetDoc(BiasAdd_ver1_doc) + .Input(0, + "X", + "Input tensor. Dimensions are (N, S, C), where N is the batch size, S is image size H*W, and C is number of channels", + "T") + .Input(1, + "bias", + "Bias tensor. Dimensions are (C)", + "T") + .Input(2, + "skip", + "Residual tensor. 
Dimensions are (N, S, C)", + "T") + .Output(0, + "Y", + "The output tensor with dimensions (N, S, C)", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index bd8469909fe7f..548f0a7ecc353 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -50,6 +50,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasAdd); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist); @@ -139,6 +140,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 842d5cd943226..53106715cd36c 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -203,6 +203,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, "GroupNorm": self._infer_GroupNorm, "BiasSplitGelu": self._infer_BiasSplitGelu, + "BiasAdd": self._infer_BiasAdd, "NhwcConv": self._infer_NhwcConv, } self.aten_op_dispatcher_ = { @@ -443,6 +444,7 @@ def _onnx_infer_single_node(self, node): "MultiHeadAttention", "GroupNorm", "BiasSplitGelu", + "BiasAdd", "NhwcConv", ] @@ -2104,6 +2106,9 @@ def _infer_BiasSplitGelu(self, node): output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) + def _infer_BiasAdd(self, node): + self._propagate_shape_and_type(node) + def _infer_PythonOp(self, node): output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 437e72fce0a31..a7904c39f8491 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -111,7 +111,6 @@ def make_value_info_from_tensor(tensor): "NonMaxSuppression", "TopK", "RoiAlign", - "Resize", "Range", "CumSum", "Min", @@ -120,6 +119,10 @@ def make_value_info_from_tensor(tensor): ] +# Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices +ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]} + + class InitializerTracker: """Class for keeping track of initializer.""" @@ -198,6 +201,12 @@ def convert_float_to_float16( queue = [] value_info_list = [] node_list = [] + + # Some operators (Like Resize or GroupNorm) have data type fixed as float for some input. 
+ # When it is converted to float16, there are mixed types: some inputs are float32 and some are float16. + # This list keeps track of such nodes that are not in block list. + mixed_float_type_node_list = [] + # type inference on input model if func_infer_shape is not None: model = func_infer_shape(model) @@ -276,9 +285,11 @@ def convert_float_to_float16( n.output[i] = name_mapping[n.output[i]] is_node_blocked = n.op_type in op_block_list or n.name in node_block_list - for input in n.input: - if input in fp32_initializers: - fp32_initializers[input].add_node(n, is_node_blocked) + for i, input_name in enumerate(n.input): + if input_name in fp32_initializers: + # For Resize/GroupNorm, only the first input can be float16 + use_fp32_weight = is_node_blocked or (n.op_type in ["Resize", "GroupNorm"] and i != 0) + fp32_initializers[input_name].add_node(n, use_fp32_weight) if is_node_blocked: node_list.append(n) @@ -288,8 +299,14 @@ def convert_float_to_float16( if attr.name == "to" and attr.i == 1: attr.i = 10 break - for attr in n.attribute: - next_level.append(attr) + + # For Resize/GroupNorm, attribute data type cannot be changed + if n.op_type not in ["Resize", "GroupNorm"]: + for attr in n.attribute: + next_level.append(attr) + else: + mixed_float_type_node_list.append(n) + # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto) # and process node.attribute.t and node.attribute.tensors (TensorProto) if isinstance(q, onnx_proto.AttributeProto): @@ -329,15 +346,36 @@ def convert_float_to_float16( ) ) + # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. + for node in mixed_float_type_node_list: + for i, input_name in enumerate(node.input): + if i not in ALWAYS_FLOAT_INPUTS[node.op_type]: + continue + for value_info in value_info_list: + if input_name == value_info.name: + # create new value_info for current node's new input name + new_value_info = model.graph.value_info.add() + new_value_info.CopyFrom(value_info) + output_name = node.name + "_input_cast_" + str(i) + new_value_info.name = output_name + new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT + # add Cast node (from tensor(float16) to tensor(float) before current node + node_name = node.name + "_input_cast" + str(i) + new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] + model.graph.node.extend(new_node) + # change current node's input name + node.input[i] = output_name + break + # process the nodes in block list that doesn't support tensor(float16) for node in node_list: # if input's name is in the value_info_list meaning input is tensor(float16) type, # insert a float16 to float Cast node before the node, # change current node's input name and create new value_info for the new name for i in range(len(node.input)): - input = node.input[i] + input_name = node.input[i] for value_info in value_info_list: - if input == value_info.name: + if input_name == value_info.name: # create new value_info for current node's new input name new_value_info = model.graph.value_info.add() new_value_info.CopyFrom(value_info) @@ -346,7 +384,7 @@ def convert_float_to_float16( new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT # add Cast node (from tensor(float16) to tensor(float) before current node node_name = node.name + "_input_cast" + str(i) - new_node = [helper.make_node("Cast", [input], [output_name], to=1, name=node_name)] + new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, 
name=node_name)] model.graph.node.extend(new_node) # change current node's input name node.input[i] = output_name diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 0441ce494d560..361baaedc4a95 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -92,9 +92,6 @@ def create_attention_node( q_matmul (NodeProto): MatMul node in fully connection for Q k_matmul (NodeProto): MatMul node in fully connection for K v_matmul (NodeProto): MatMul node in fully connection for V - q_add (NodeProto): Add bias node in fully connection for Q - k_add (NodeProto): Add bias node in fully connection for K - v_add (NodeProto): Add bias node in fully connection for V num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. input (str): input name diff --git a/onnxruntime/python/tools/transformers/fusion_attention_vae.py b/onnxruntime/python/tools/transformers/fusion_attention_vae.py new file mode 100644 index 0000000000000..e91a8a61fcc24 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_attention_vae.py @@ -0,0 +1,304 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Tuple, Union + +import numpy as np +from fusion_base import Fusion +from onnx import NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionAttentionVae(Fusion): + """ + Fuse Attention subgraph of Vae Decoder into one Attention node. + """ + + def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int): + super().__init__(model, "Attention", ["Softmax"]) + self.hidden_size = hidden_size + self.num_heads = num_heads + + # Flags to show warning only once + self.num_heads_warning = True + self.hidden_size_warning = True + + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> Tuple[int, int]: + """Detect num_heads and hidden_size from a reshape node. + + Args: + reshape_q (NodeProto): reshape node for Q + add_q (NodeProto): add node for Q + + Returns: + Tuple[int, int]: num_heads and hidden_size + """ + concat = self.model.get_parent(reshape_q, 1) + if concat is None or len(concat.input) != 4: + return self.num_heads, self.hidden_size # Fall back to user specified value + + value = self.model.get_constant_value(concat.input[2]) + if not (value is not None and isinstance(value, np.ndarray) and value.size == 1): + return self.num_heads, self.hidden_size # Fall back to user specified value + num_heads = int(value) + if num_heads <= 0: + return self.num_heads, self.hidden_size # Fall back to user specified value + + _, bias = self.model.get_constant_input(add_q) + if (bias is None) or (not isinstance(bias, np.ndarray)) or bias.ndim != 1: + return self.num_heads, self.hidden_size # Fall back to user specified value + + hidden_size = bias.shape[0] + + if self.num_heads > 0 and num_heads != self.num_heads: + if self.num_heads_warning: + logger.warning( + "Detected number of attention heads is %d. 
Ignore --num_heads %d", num_heads, self.num_heads + ) + self.num_heads_warning = False # Do not show the warning more than once + + if self.hidden_size > 0 and hidden_size != self.hidden_size: + if self.hidden_size_warning: + logger.warning("Detected hidden size is %d. Ignore --hidden_size %d", hidden_size, self.hidden_size) + self.hidden_size_warning = False # Do not show the warning more than once + + return num_heads, hidden_size + + def create_attention_node( + self, + q_matmul: NodeProto, + q_add: NodeProto, + k_matmul: NodeProto, + k_add: NodeProto, + v_matmul: NodeProto, + v_add: NodeProto, + num_heads: int, + hidden_size: int, + input_name: str, + output_name: str, + ) -> Union[NodeProto, None]: + """Create an Attention node. + + Args: + q_matmul (NodeProto): MatMul node in fully connection for Q + q_add (NodeProto): Add bias node in fully connection for Q + k_matmul (NodeProto): MatMul node in fully connection for K + k_add (NodeProto): Add bias node in fully connection for K + v_matmul (NodeProto): MatMul node in fully connection for V + v_add (NodeProto): Add bias node in fully connection for V + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. + input_name (str): input name + output_name (str): output name + + Returns: + Union[NodeProto, None]: the node created or None if failed. + """ + if q_matmul.input[0] != input_name or k_matmul.input[0] != input_name or v_matmul.input[0] != input_name: + logger.debug( + "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s", + q_matmul.input[0], + k_matmul.input[0], + v_matmul.input[0], + ) + return None + + if hidden_size > 0 and (hidden_size % num_heads) != 0: + logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads) + return None + + q_weight_tensor = self.model.get_initializer(q_matmul.input[1]) + k_weight_tensor = self.model.get_initializer(k_matmul.input[1]) + v_weight_tensor = self.model.get_initializer(v_matmul.input[1]) + if not (q_weight_tensor and k_weight_tensor and v_weight_tensor): + return None + + q_bias_tensor = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0]) + k_bias_tensor = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0]) + v_bias_tensor = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0]) + + q_bias = numpy_helper.to_array(q_bias_tensor) + k_bias = numpy_helper.to_array(k_bias_tensor) + v_bias = numpy_helper.to_array(v_bias_tensor) + + q_bias_shape = np.prod(q_bias.shape) + k_bias_shape = np.prod(k_bias.shape) + v_bias_shape = np.prod(v_bias.shape) + + # Sometimes weights are stored in fp16 + if q_weight_tensor.data_type == 10: + logger.debug("weights are in fp16. 
Please run fp16 conversion after optimization") + return None + + q_weight = numpy_helper.to_array(q_weight_tensor) + k_weight = numpy_helper.to_array(k_weight_tensor) + v_weight = numpy_helper.to_array(v_weight_tensor) + + # assert q and k have same shape as expected + if q_weight.shape != k_weight.shape or q_weight.shape != v_weight.shape: + return None + + qw_in_size = q_weight.shape[0] + kw_in_size = k_weight.shape[0] + vw_in_size = v_weight.shape[0] + + assert qw_in_size == kw_in_size and kw_in_size == vw_in_size + + if hidden_size > 0 and hidden_size != qw_in_size: + raise ValueError( + f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). " + "Please provide a correct input hidden size or pass in 0" + ) + + # All the matrices can have the same shape or q, k matrics can have the same shape with v being different + # For 2d weights, the shapes would be [in_size, out_size]. + # For 3d weights, shape would be [in_size, a, b] where a*b = out_size + qw_out_size = np.prod(q_weight.shape[1:]) + + qkv_weight = np.stack((q_weight, k_weight, v_weight), axis=1) + qkv_weight_dim = 3 * int(qw_out_size) + + attention_node_name = self.model.create_node_name("Attention") + + assert q_bias_shape == k_bias_shape == v_bias_shape + + qkv_bias_dim = 0 + qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0) + qkv_bias_dim = 3 * q_bias_shape + + weight = helper.make_tensor( + name=attention_node_name + "_qkv_weight", + data_type=TensorProto.FLOAT, + dims=[qw_in_size, qkv_weight_dim], + vals=qkv_weight.flatten().tolist(), + ) + + self.model.add_initializer(weight, self.this_graph_name) + + # No bias, use zeros + qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) + qkv_bias_dim = 3 * hidden_size + + bias = helper.make_tensor( + name=attention_node_name + "_qkv_bias", + data_type=TensorProto.FLOAT, + dims=[qkv_bias_dim], + vals=qkv_bias.flatten().tolist(), + ) + self.model.add_initializer(bias, self.this_graph_name) + + attention_inputs = [ + input_name, + attention_node_name + "_qkv_weight", + attention_node_name + "_qkv_bias", + ] + + attention_node = helper.make_node( + "Attention", + inputs=attention_inputs, + outputs=[output_name], + name=attention_node_name, + ) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + self.increase_counter("Attention (self attention)") + return attention_node + + def fuse(self, softmax_node, input_name_to_nodes, output_name_to_node): + matmul_qkv = self.model.find_first_child_by_type(softmax_node, "MatMul", input_name_to_nodes, recursive=False) + if matmul_qkv is None: + return + + reshape_qkv = self.model.find_first_child_by_type(matmul_qkv, "Reshape", input_name_to_nodes, recursive=False) + if reshape_qkv is None: + return + + transpose_qkv = self.model.find_first_child_by_type( + reshape_qkv, "Transpose", input_name_to_nodes, recursive=False + ) + if transpose_qkv is None: + return + + reshape_out = self.model.find_first_child_by_type( + transpose_qkv, "Reshape", input_name_to_nodes, recursive=False + ) + if reshape_out is None: + return + + matmul_out = self.model.find_first_child_by_type(reshape_out, "MatMul", input_name_to_nodes, recursive=False) + if matmul_out is None: + return + + add_out = self.model.find_first_child_by_type(matmul_out, "Add", input_name_to_nodes, recursive=False) + if add_out is None: + return + + transpose_out = self.model.find_first_child_by_type(add_out, "Transpose", input_name_to_nodes, recursive=False) + if transpose_out is 
None: + return + + v_nodes = self.model.match_parent_path( + matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None] + ) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return + (_, _, _, add_v, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) + if qk_nodes is not None: + (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return + + q_nodes = self.model.match_parent_path( + matmul_qk, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None] + ) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return + (_, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes + k_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None] + ) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return + (_, _, _, _, add_k, matmul_k) = k_nodes + + attention_last_node = reshape_out + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, add_q) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + add_q, + matmul_k, + add_k, + matmul_v, + add_v, + q_num_heads, + q_hidden_size, + matmul_q.input[0], + attention_last_node.output[0], + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index f338c0a8c4438..43969aa04bf20 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -52,11 +52,11 @@ def apply(self): if self.fused_count: for key, value in self.fused_count.items(): if value: - logger.info(f"Fused {key} count: {value}") + logger.info(f"Fused {key}: {value}") else: count = op_list.count(self.fused_op_type) if count > 0: - logger.info(f"Fused {self.description} count: {count}") + logger.info(f"Fused {self.description}: {count}") self.model.remove_nodes(self.nodes_to_remove) self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name) diff --git a/onnxruntime/python/tools/transformers/fusion_bias_add.py b/onnxruntime/python/tools/transformers/fusion_bias_add.py new file mode 100644 index 0000000000000..cdf54a3629726 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_bias_add.py @@ -0,0 +1,58 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Dict + +from fusion_base import Fusion +from numpy import ndarray +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionBiasAdd(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "BiasAdd", "Add") + + def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + Fuse Add bias and Add skip connection into BiasAdd + """ + + nodes = self.model.match_parent_path( + add_node, + ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"], + [0, None, 0, 0, 0], + output_name_to_node, + ) + + if nodes is None: + return + + bias_node = nodes[0] + skip_layer_norm = nodes[-1] + + # Check skip connection is from SkipLayerNormalization output + if not (add_node.input[1] in skip_layer_norm.output): + return + + bias_index, bias_value = self.model.get_constant_input(bias_node) + if not (isinstance(bias_index, int) and (bias_value is not None) and isinstance(bias_value, ndarray)): + return + if bias_value.ndim != 1: + return + + self.nodes_to_remove.extend([add_node, bias_node]) + node_name = self.model.create_node_name("BiasAdd") + fused_node = helper.make_node( + "BiasAdd", + inputs=[bias_node.input[1 - bias_index], bias_node.input[bias_index], add_node.input[1]], + outputs=[add_node.output[0]], + name=node_name, + ) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + self.node_name_to_graph_name[node_name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index a0a4d7c16de0b..c04c806eaaa86 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -7,7 +7,6 @@ import numpy as np from fusion_base import Fusion -from fusion_utils import FusionUtils from onnx import TensorProto, helper from onnx_model import OnnxModel @@ -143,35 +142,22 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): # instance_norm_scale might from Constant node. Use prune graph to clear it. self.prune_graph = True - # Right now GroupNorm only support float16 input. Need add a Cast in fp32 model. 
- utils = FusionUtils(self.model) - - input = root - output = last_node.output[0] - if weight.dtype == np.float32: - # Add a Cast node to get float16 input for GroupNorm - cast_input, _cast_node = utils.cast_input(root, "float16") - input = cast_input - - # Add a Cast node to convert back to float32 after GroupNorm - output = group_norm_name + "_out" - cast_node = helper.make_node("Cast", inputs=[group_norm_name + "_out"], outputs=[last_node.output[0]]) - cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.FLOAT))]) - self.model.add_node(cast_node) + input_name = root + output_name = last_node.output[0] # NCHW to NHWC transpose_input = helper.make_node( "Transpose", - [input], - [input + "_NHWC"], + [input_name], + [input_name + "_NHWC"], name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"), perm=[0, 2, 3, 1], ) new_node = helper.make_node( "GroupNorm", - inputs=[input + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"], - outputs=[output + "_NHWC"], + inputs=[input_name + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"], + outputs=[output_name + "_NHWC"], name=group_norm_name, ) @@ -183,8 +169,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): # NHWC to NCHW transpose_output = helper.make_node( "Transpose", - [output + "_NHWC"], - [output], + [output_name + "_NHWC"], + [output_name], name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"), perm=[0, 3, 1, 2], ) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index cdfa2c626fc57..7629a6886ab94 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -51,9 +51,10 @@ def __init__(self, model_type): ) # options for stable diffusion - self.enable_group_norm = model_type == "unet" - self.enable_bias_splitgelu = model_type == "unet" - self.enable_packed_kv = model_type == "unet" + self.enable_group_norm = model_type in ["unet", "vae"] + self.enable_bias_splitgelu = model_type in ["unet"] + self.enable_packed_kv = model_type in ["unet"] + self.enable_bias_add = model_type in ["unet"] def use_raw_attention_mask(self, use_raw_mask=True): if use_raw_mask: diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py index d92ddd5f8e678..84961f799a122 100644 --- a/onnxruntime/python/tools/transformers/fusion_transpose.py +++ b/onnxruntime/python/tools/transformers/fusion_transpose.py @@ -25,6 +25,9 @@ def fuse( output_name_to_node: Dict[str, NodeProto], ): """ + Note that onnxruntime will do comprehensive transpose optimization after loading model. + The purpose of this fusion is to make graph clean before running onnxruntime. 
+ Case 1: (input)-->Transpose(perm=a)-->Transpose(perm=b)--> After: @@ -38,8 +41,6 @@ def fuse( (input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore) | +----->Cast --> Transpose(perm=a*b)--> - - """ transpose_b = transpose_node if transpose_b.input[0] not in output_name_to_node: diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 07fdf490337a4..0945be6cc6898 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -74,7 +74,23 @@ def remove_cast_int32(self, input_name: str): self.model.replace_input_of_all_nodes(output_name, input_name) @staticmethod - def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes): + def update_node_input(node, i, new_input_name, input_name_to_nodes): + old_input_reference = 0 + if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]: + input_name_to_nodes[node.input[i]].remove(node) + old_input_reference = len(input_name_to_nodes[node.input[i]]) + + node.input[i] = new_input_name + + if new_input_name in input_name_to_nodes: + input_name_to_nodes[new_input_name].append(node) + else: + input_name_to_nodes[new_input_name] = [node] + + return old_input_reference + + @staticmethod + def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0): """ Before: (input)-->parent-->node-->(output) @@ -83,20 +99,16 @@ def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes): | +----->node-->(output) - This function returns a flag about whether the parent node can be removed. - Note that this function assumes the node has first input links from parent! + This function returns a flag whether the parent node can be removed. """ - parent_can_be_removed = False - input_name_to_nodes[node.input[0]].remove(node) + + old_input_name = node.input[node_input_index] + new_input_name = parent_node.input[parent_input_index] + old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes) + # We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore. - if len(input_name_to_nodes[node.input[0]]) == 0 and not model.find_graph_output( - node.input[0] - ): # checks main graph output. 
TODO: deal with subgraph - parent_can_be_removed = True - # self.nodes_to_remove.append(transpose_a) - - input_name_to_nodes[parent_node.input[0]].append(node) - node.input[0] = parent_node.input[0] + parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name) + return parent_can_be_removed @staticmethod diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 9a00dc8684f32..236923d5ff496 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -14,10 +14,7 @@ } -def get_test_settings(): - height = 512 - width = 512 - num_inference_steps = 50 +def example_prompts(): prompts = [ "a photo of an astronaut riding a horse on mars", "cute grey cat with blue eyes, wearing a bowtie, acrylic painting", @@ -31,11 +28,11 @@ def get_test_settings(): "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k", ] - return height, width, num_inference_steps, prompts + return prompts def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_safety_checker: bool): - from diffusers import OnnxStableDiffusionPipeline + from diffusers import DPMSolverMultistepScheduler, OnnxStableDiffusionPipeline import onnxruntime @@ -54,6 +51,8 @@ def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_saf provider=provider, use_auth_token=True, ) + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.set_progress_bar_config(disable=True) if disable_safety_checker: pipe.safety_checker = None @@ -63,14 +62,15 @@ def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_saf def get_torch_pipeline(model_name: str, disable_safety_checker: bool): - from diffusers import StableDiffusionPipeline + from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline from torch import channels_last, float16 pipe = StableDiffusionPipeline.from_pretrained( model_name, torch_dtype=float16, revision="fp16", use_auth_token=True ).to("cuda") - pipe.unet.to(memory_format=channels_last) # in-place operation + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.set_progress_bar_config(disable=True) if disable_safety_checker: pipe.safety_checker = None @@ -84,66 +84,111 @@ def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, dis return f"{engine}_{short_model_name}_b{batch_size}" + ("" if disable_safety_checker else "_safe") -def run_ort_pipeline(pipe, batch_size: int, image_filename_prefix: str): +def run_ort_pipeline(pipe, batch_size: int, image_filename_prefix: str, height, width, steps, num_prompts, batch_count): from diffusers import OnnxStableDiffusionPipeline assert isinstance(pipe, OnnxStableDiffusionPipeline) - height, width, num_inference_steps, prompts = get_test_settings() + prompts = example_prompts() - pipe("warm up", height, width, num_inference_steps=2) + pipe("warm up", height, width, num_inference_steps=steps) latency_list = [] for i, prompt in enumerate(prompts): - input_prompts = [prompt] * batch_size - inference_start = time.time() - image = pipe(input_prompts, height, width, num_inference_steps).images[0] - inference_end = time.time() - - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency} 
seconds") - image.save(f"{image_filename_prefix}_{i}.jpg") + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + images = pipe( + prompt, + height, + width, + num_inference_steps=steps, + negative_prompt=None, + guidance_scale=7.5, + num_images_per_prompt=batch_size, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + print("Average latency in seconds:", sum(latency_list) / len(latency_list)) -def run_torch_pipeline(pipe, batch_size: int, image_filename_prefix: str): +def run_torch_pipeline( + pipe, batch_size: int, image_filename_prefix: str, height, width, steps, num_prompts, batch_count +): import torch - height, width, num_inference_steps, prompts = get_test_settings() + prompts = example_prompts() - pipe("warm up", height, width, num_inference_steps=2) + pipe("warm up", height, width, num_inference_steps=steps) torch.set_grad_enabled(False) latency_list = [] for i, prompt in enumerate(prompts): - input_prompts = [prompt] * batch_size - torch.cuda.synchronize() - inference_start = time.time() - image = pipe(input_prompts, height, width, num_inference_steps).images[0] + if i >= num_prompts: + break torch.cuda.synchronize() - inference_end = time.time() - - latency = inference_end - inference_start - latency_list.append(latency) - print(f"Inference took {latency} seconds") - image.save(f"{image_filename_prefix}_{i}.jpg") + for j in range(batch_count): + inference_start = time.time() + images = pipe( + prompt=prompt, + height=height, + width=width, + num_inference_steps=steps, + guidance_scale=7.5, + negative_prompt=None, + num_images_per_prompt=batch_size, + generator=None, # torch.Generator + ).images + + torch.cuda.synchronize() + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") print("Average latency in seconds:", sum(latency_list) / len(latency_list)) -def run_ort(model_name: str, directory: str, provider: str, batch_size: int, disable_safety_checker: bool): +def run_ort( + model_name: str, + directory: str, + provider: str, + batch_size: int, + disable_safety_checker: bool, + height, + width, + steps, + num_prompts, + batch_count, +): load_start = time.time() pipe = get_ort_pipeline(model_name, directory, provider, disable_safety_checker) load_end = time.time() print(f"Model loading took {load_end - load_start} seconds") image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker) - run_ort_pipeline(pipe, batch_size, image_filename_prefix) - - -def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool): + run_ort_pipeline(pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count) + + +def run_torch( + model_name: str, + batch_size: int, + disable_safety_checker: bool, + height, + width, + steps, + num_prompts, + batch_count, +): import torch torch.backends.cudnn.enabled = True @@ -159,7 +204,7 @@ def run_torch(model_name: str, batch_size: int, disable_safety_checker: bool): image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) with torch.inference_mode(): - run_torch_pipeline(pipe, batch_size, 
image_filename_prefix) + run_torch_pipeline(pipe, batch_size, image_filename_prefix, height, width, steps, num_prompts, batch_count) def parse_arguments(): @@ -201,7 +246,58 @@ def parse_arguments(): ) parser.set_defaults(enable_safety_checker=False) - parser.add_argument("-b", "--batch_size", type=int, default=1) + parser.add_argument( + "-b", + "--batch_size", + type=int, + default=1, + choices=[1, 2, 4, 8, 16, 32], + help="Number of images per batch", + ) + + parser.add_argument( + "--height", + required=False, + type=int, + default=512, + help="Output image height", + ) + + parser.add_argument( + "--width", + required=False, + type=int, + default=512, + help="Output image width", + ) + + parser.add_argument( + "-s", + "--steps", + required=False, + type=int, + default=50, + help="Number of steps", + ) + + parser.add_argument( + "-n", + "--num_prompts", + required=False, + type=int, + default=1, + help="Number of prompts", + ) + + parser.add_argument( + "-c", + "--batch_count", + required=False, + type=int, + choices=range(1, 11), + default=10, + help="Number of batches to test", + ) args = parser.parse_args() return args @@ -219,13 +315,33 @@ def main(): # Need remove a line https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L307 # in diffuers to run batch_size > 1. assert ( - args.enable_safety_checker + not args.enable_safety_checker ), "batch_size > 1 is not compatible with safety checker due to a bug in diffuers" provider = "CUDAExecutionProvider" # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffuers - run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker) + run_ort( + sd_model, + args.pipeline, + provider, + args.batch_size, + not args.enable_safety_checker, + args.height, + args.width, + args.steps, + args.num_prompts, + args.batch_count, + ) else: - run_torch(sd_model, args.batch_size, not args.enable_safety_checker) + run_torch( + sd_model, + args.batch_size, + not args.enable_safety_checker, + args.height, + args.width, + args.steps, + args.num_prompts, + args.batch_count, + ) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 932be4a19ae6b..31b2a22c2f615 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -10,35 +10,54 @@ # pip install -r requirements.txt # huggingface-cli login # wget https://raw.githubusercontent.com/huggingface/diffusers/v0.12.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py -# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32 +# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/sd-v1-5 +# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/sd-v2-1 # Note that this script might not be compatible with older or newer version of diffusers. 
# Then you can use this script to convert them to float16 like the following: -# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 +# python optimize_pipeline.py -i $ONNX_ROOT/sd-v1-5 -o $ONNX_ROOT/sd-v1-5-fp16 --float16 # Or -# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 +# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/sd-v1-5 -o $ONNX_ROOT/sd-v1-5-fp16 --float16 # # Note that output model is for CUDA Execution Provider. It might not run in CPU Execution Provider. -# Stable diffusion 2.1 model will get black images using float16 Attention. It is a known issue that we are working on. +# +# Stable diffusion 2.1 model will get black images using float16 Attention. A workaround is to force it to run in float32: +# python optimize_pipeline.py -i $ONNX_ROOT/sd-v2-1 -o $ONNX_ROOT/sd-v2-1-fp16 --float16 --force_fp32_ops unet:Attention + import argparse import logging import os import shutil import sys +import tempfile from pathlib import Path +from typing import List import coloredlogs +import onnx +from packaging import version + +import onnxruntime sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from fusion_options import FusionOptions -from optimizer import optimize_model # noqa: E402 +from onnx_model_clip import ClipOnnxModel +from onnx_model_unet import UnetOnnxModel +from onnx_model_vae import VaeOnnxModel +from optimizer import optimize_by_onnxruntime, optimize_model logger = logging.getLogger(__name__) def optimize_sd_pipeline( - source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool + source_dir: Path, + target_dir: Path, + overwrite: bool, + use_external_data_format: bool, + float16: bool, + force_fp32_ops: List[str], + enable_runtime_optimization: bool, ): """Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16. @@ -48,54 +67,115 @@ def optimize_sd_pipeline( overwrite (bool): Overwrite files if exists. use_external_data_format (bool): save onnx model to two files: one for onnx graph, another for weights float16 (bool): use half precision + force_fp32_ops(List[str]): operators that are forced to run in float32. + enable_runtime_optimization(bool): run graph optimization using Onnx Runtime.
Raises: RuntimeError: input onnx model does not exist RuntimeError: output onnx model path existed """ - dirs_with_onnx = ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"] - for name in dirs_with_onnx: + model_type_mapping = { + "unet": "unet", + "vae_encoder": "vae", + "vae_decoder": "vae", + "text_encoder": "clip", + "safety_checker": "unet", + } + + model_type_class_mapping = { + "unet": UnetOnnxModel, + "vae": VaeOnnxModel, + "clip": ClipOnnxModel, + } + + force_fp32_operators = { + "unet": [], + "vae_encoder": [], + "vae_decoder": [], + "text_encoder": [], + "safety_checker": [], + } + + if force_fp32_ops: + for fp32_operator in force_fp32_ops: + parts = fp32_operator.split(":") + if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()): + force_fp32_operators[parts[0]].append(parts[1]) + else: + raise ValueError( + f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}" + ) + + for name, model_type in model_type_mapping.items(): onnx_model_path = source_dir / name / "model.onnx" if not os.path.exists(onnx_model_path): message = f"input onnx model does not exist: {onnx_model_path}." - if name not in ["safety_checker", "feature_extractor"]: + if name not in ["safety_checker"]: raise RuntimeError(message) continue + # Prepare output directory + optimized_model_path = target_dir / name / "model.onnx" + output_dir = optimized_model_path.parent + if optimized_model_path.exists(): + if not overwrite: + raise RuntimeError(f"output onnx model path existed: {optimized_model_path}") + + if output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + # Graph fusion before fp16 conversion, otherwise they cannot be fused later. # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. - logger.info(f"optimize {onnx_model_path}...") + logger.info(f"Optimize {onnx_model_path}...") + + # There are some optimizations that are not available in v1.14 or older versions + has_all_optimizations = version.parse(onnxruntime.__version__) > version.parse("1.14.0") - fusion_options = FusionOptions("unet") + fusion_options = FusionOptions(model_type) fusion_options.enable_packed_kv = float16 + fusion_options.enable_bias_add = has_all_optimizations m = optimize_model( str(onnx_model_path), - model_type="unet", + model_type=model_type, num_heads=0, # will be deduced from graph hidden_size=0, # will be deduced from graph opt_level=0, optimization_options=fusion_options, - use_gpu=False, + use_gpu=True, ) if float16: - logger.info("convert %s to float16 ...", name) - m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize", "GroupNorm"]) - - optimized_model_path = target_dir / name / "model.onnx" - output_dir = optimized_model_path.parent - if optimized_model_path.exists(): - if not overwrite: - raise RuntimeError(f"output onnx model path existed: {optimized_model_path}") - - if output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - + logger.info("Convert %s to float16 ...", name) + op_block_list = ["RandomNormalLike"] + m.convert_float_to_float16( + keep_io_types=False, + op_block_list=op_block_list + force_fp32_operators[name], + ) + + if enable_runtime_optimization and (float16 or (name not in ["unet"])): + # Use this step to see the final graph that is executed by Onnx Runtime. + # Note that ORT cannot save models larger than 2GB, so we exclude the unet float32 model.
+ # This step is optional since it has no impact on performance except model loading time. + with tempfile.TemporaryDirectory() as tmp_dir: + # Save to a temporary file so that we can load it with Onnx Runtime. + logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") + tmp_model_path = Path(tmp_dir) / "model.onnx" + m.save_model_to_file(str(tmp_model_path)) + ort_optimized_model_path = tmp_model_path + optimize_by_onnxruntime( + str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + ) + model = onnx.load(str(ort_optimized_model_path), load_external_data=True) + m = model_type_class_mapping[model_type](model) + + m.get_operator_statistics() + m.get_fused_operator_statistics() m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format) - logger.info("%s => %s", onnx_model_path, optimized_model_path) + logger.info("%s is optimized", name) + logger.info("*" * 20) def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): @@ -117,7 +197,7 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): if not os.path.exists(source_path): message = f"source path does not exist: {source_path}" - if name not in ["safety_checker", "feature_extractor"]: + if name not in ["feature_extractor"]: raise RuntimeError(message) continue @@ -177,6 +257,22 @@ def parse_arguments(): ) parser.set_defaults(float16=False) + parser.add_argument( + "--force_fp32_ops", + required=False, + nargs="+", + type=str, + help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!", + ) + + parser.add_argument( + "--inspect", + required=False, + action="store_true", + help="Inspect the optimized graph from Onnx Runtime for debugging purpose. 
This option has no impact on model performance.", + ) + parser.set_defaults(inspect=False) + parser.add_argument( "--overwrite", required=False, @@ -202,9 +298,16 @@ def parse_arguments(): def main(): coloredlogs.install(fmt="%(funcName)20s: %(message)s") args = parse_arguments() + logger.info("Arguments: %s", str(args)) copy_extra_directory(Path(args.input), Path(args.output), args.overwrite) optimize_sd_pipeline( - Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format, args.float16 + Path(args.input), + Path(args.output), + args.overwrite, + args.use_external_data_format, + args.float16, + args.force_fp32_ops, + args.inspect, ) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 276a9428ecf72..c8380c93f591c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -802,11 +802,14 @@ def prune_graph(self, outputs=None): self.model.graph.input.remove(input) if input_to_remove or output_to_remove or nodes_to_remove: - logger.info( - "Graph pruned: {} inputs, {} outputs and {} nodes are removed".format( - len(input_to_remove), len(output_to_remove), len(nodes_to_remove) - ) - ) + removed = [] + if input_to_remove: + removed.append(f"{len(input_to_remove)} inputs") + if output_to_remove: + removed.append(f"{len(output_to_remove)} outputs") + if nodes_to_remove: + removed.append(f"{len(nodes_to_remove)} nodes") + logger.info("Removed %s", ", ".join(removed)) self.update_graph() @@ -1022,6 +1025,18 @@ def get_opset_version(self): return opset.version raise RuntimeError("ONNX model has no opset for default domain") + def get_operator_statistics(self, include_domain=False): + """ + Returns node count of operators. + """ + op_count = {} + for node in self.nodes(): + op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type + op_count[op] = 1 if op not in op_count else (op_count[op] + 1) + + logger.info(f"Operators:{op_count}") + return op_count + @staticmethod def has_same_value(tensor1: TensorProto, tensor2: TensorProto) -> bool: """Returns True when two tensors have same value. diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py new file mode 100644 index 0000000000000..93e8623768067 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -0,0 +1,33 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger + +from onnx import ModelProto +from onnx_model_unet import UnetOnnxModel + +logger = getLogger(__name__) + + +class ClipOnnxModel(UnetOnnxModel): + def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): + super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. 
+ """ + op_count = {} + ops = [ + "Attention", + "LayerNormalization", + "SkipLayerNormalization", + ] + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators:{op_count}") + return op_count diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 32a98149825c3..b460d7d8a7f8f 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -7,6 +7,7 @@ from typing import Optional from fusion_attention_unet import FusionAttentionUnet +from fusion_bias_add import FusionBiasAdd from fusion_biassplitgelu import FusionBiasSplitGelu from fusion_group_norm import FusionGroupNorm from fusion_nhwc_conv import FusionNhwcConv @@ -36,7 +37,6 @@ def preprocess(self): self.remove_useless_div() def postprocess(self): - self.merge_sequential_transpose() self.prune_graph() self.remove_unused_constant() @@ -54,14 +54,14 @@ def remove_useless_div(self): if nodes_to_remove: self.remove_nodes(nodes_to_remove) - logger.info("Removed %d useless Div (by 1) nodes", len(nodes_to_remove)) + logger.info("Removed %d Div nodes", len(nodes_to_remove)) def convert_conv_to_nhwc(self): # Do not update weight here since save external data has a bug conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False) conv_to_nhwc_conv.apply() - def merge_sequential_transpose(self): + def merge_adjacent_transpose(self): fusion_transpose = FusionTranspose(self) fusion_transpose.apply() @@ -89,6 +89,20 @@ def merge_sequential_transpose(self): if total: logger.info("Removed %d Transpose nodes", total) + def fuse_attention(self, options: Optional[FusionOptions] = None): + # Self Attention + self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False) + self_attention_fusion.apply() + + # Cross Attention + enable_packed_kv = (options is None) or options.enable_packed_kv + cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv) + cross_attention_fusion.apply() + + def fuse_bias_add(self): + fusion = FusionBiasAdd(self) + fusion.apply() + def optimize(self, options: Optional[FusionOptions] = None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() @@ -117,12 +131,7 @@ def optimize(self, options: Optional[FusionOptions] = None): bias_split_gelu_fusion.apply() if (options is None) or options.enable_attention: - self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False) - self_attention_fusion.apply() - - enable_packed_kv = (options is None) or options.enable_packed_kv - cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv) - cross_attention_fusion.apply() + self.fuse_attention() if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() @@ -141,6 +150,10 @@ def optimize(self, options: Optional[FusionOptions] = None): if options is not None and options.enable_gelu_approximation: self.gelu_approximation() + self.merge_adjacent_transpose() + + self.fuse_bias_add() + self.postprocess() logger.info(f"opset version: {self.get_opset_version()}") @@ -153,8 +166,6 @@ def get_fused_operator_statistics(self): ops = [ "Attention", "MultiHeadAttention", - "Gelu", - "FastGelu", "LayerNormalization", "SkipLayerNormalization", "BiasSplitGelu", diff --git 
a/onnxruntime/python/tools/transformers/onnx_model_vae.py b/onnxruntime/python/tools/transformers/onnx_model_vae.py new file mode 100644 index 0000000000000..47d3f9ddfe6fd --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_vae.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import Optional + +from fusion_attention_vae import FusionAttentionVae +from fusion_options import FusionOptions +from onnx import ModelProto +from onnx_model_unet import UnetOnnxModel + +logger = getLogger(__name__) + + +class VaeOnnxModel(UnetOnnxModel): + def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): + assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) + super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + + def fuse_attention(self, options: Optional[FusionOptions] = None): + # Self Attention + self_attention_fusion = FusionAttentionVae(self, self.hidden_size, self.num_heads) + self_attention_fusion.apply() + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "Attention", + "GroupNorm", + "NhwcConv", + ] + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators:{op_count}") + return op_count diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index a18535c10591d..65de1bf770992 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -29,10 +29,12 @@ from onnx_model_bert import BertOnnxModel from onnx_model_bert_keras import BertOnnxModelKeras from onnx_model_bert_tf import BertOnnxModelTF +from onnx_model_clip import ClipOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel +from onnx_model_vae import VaeOnnxModel logger = logging.getLogger(__name__) @@ -49,8 +51,11 @@ 0, ), # might add a class for GPT2OnnxModel for TF later. "tnlr": (TnlrOnnxModel, "pytorch", 1), - "unet": (UnetOnnxModel, "pytorch", 1), "t5": (T5OnnxModel, "pytorch", 2), + # Stable Diffusion models + "unet": (UnetOnnxModel, "pytorch", 1), + "vae": (VaeOnnxModel, "pytorch", 1), + "clip": (ClipOnnxModel, "pytorch", 1), } @@ -152,7 +157,7 @@ def optimize_by_fusion( Returns: object of an optimizer class. 
""" - if model_type not in ["bert", "unet"] and (num_heads == 0 or hidden_size == 0): + if model_type not in ["bert", "unet", "vae", "clip"] and (num_heads == 0 or hidden_size == 0): logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}") (optimizer_class, producer, _) = MODEL_TYPES[model_type] @@ -446,10 +451,11 @@ def main(): optimizer.save_model_to_file(args.output, args.use_external_data_format) - if optimizer.is_fully_optimized(): - logger.info("The model has been fully optimized.") - else: - logger.info("The model has been optimized.") + if args.model_type in ["bert", "gpt2"]: + if optimizer.is_fully_optimized(): + logger.info("The model has been fully optimized.") + else: + logger.info("The model has been optimized.") if __name__ == "__main__": diff --git a/onnxruntime/test/contrib_ops/bias_add_op_test.cc b/onnxruntime/test/contrib_ops/bias_add_op_test.cc new file mode 100644 index 0000000000000..733ef81999a10 --- /dev/null +++ b/onnxruntime/test/contrib_ops/bias_add_op_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +using namespace onnxruntime::test; + +namespace onnxruntime { +namespace test { + +#if defined(USE_CUDA) // The operator has only CUDA implementation right now +static std::vector GetExpectedResult(const std::vector& input_data, + const std::vector& bias_data, + const std::vector& skip_data) { + std::vector output_data; + output_data.reserve(input_data.size()); + + size_t bias_length = bias_data.size(); + for (size_t i = 0; i < input_data.size(); i++) { + output_data.push_back(input_data[i] + bias_data[i % bias_length] + skip_data[i]); + } + return output_data; +} + +static void RunSkipBiasGpuTest(const std::vector& input_data, + const std::vector& bias_data, + const std::vector& skip_data, + const std::vector& output_data, + const std::vector& input_dims, + const std::vector& bias_dims, + const std::vector& skip_dims, + const std::vector& output_dims, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 
530 : 0; + if (!HasCudaEnvironment(min_cuda_architecture)) { + return; + } + + OpTester tester("BiasAdd", 1, onnxruntime::kMSDomain); + + if (use_float16) { + tester.AddInput("X", input_dims, ToFloat16(input_data)); + tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); + tester.AddInput("skip", skip_dims, ToFloat16(skip_data)); + tester.AddOutput("Y", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("X", input_dims, input_data); + tester.AddInput("bias", bias_dims, bias_data); + tester.AddInput("skip", skip_dims, skip_data); + tester.AddOutput("Y", output_dims, output_data); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +static void RunBiasAddTest(int64_t batch_size, int64_t image_size, int64_t num_channels) { + std::vector input_dims = {batch_size, image_size, num_channels}; + std::vector bias_dims = {num_channels}; + std::vector& skip_dims = input_dims; + std::vector& output_dims = input_dims; + + RandomValueGenerator random{}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.3f); + std::vector bias_data = random.Gaussian(bias_dims, 0.0f, 0.3f); + std::vector skip_data = random.Gaussian(skip_dims, 0.0f, 0.3f); + std::vector output_data = GetExpectedResult(input_data, bias_data, skip_data); + + RunSkipBiasGpuTest(input_data, bias_data, skip_data, output_data, input_dims, bias_dims, skip_dims, output_dims); +} + +TEST(BiasAddTest, BiasAddTest_HiddenSize_320) { + constexpr int64_t batch_size = 2; + constexpr int64_t image_size = 5; + constexpr int64_t num_channels = 320; + RunBiasAddTest(batch_size, image_size, num_channels); +} + +TEST(BiasAddTest, BiasAddTest_HiddenSize_640) { + constexpr int64_t batch_size = 2; + constexpr int64_t image_size = 1; + constexpr int64_t num_channels = 640; + RunBiasAddTest(batch_size, image_size, num_channels); +} + +TEST(BiasAddTest, BiasAddTest_HiddenSize_1280) { + constexpr int64_t batch_size = 1; + constexpr int64_t image_size = 2; + constexpr int64_t num_channels = 1280; + RunBiasAddTest(batch_size, image_size, num_channels); +} +#endif + +} // namespace test +} // namespace onnxruntime
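For reference, the BiasAdd contrib op exercised by the test above computes Y = X + bias + skip, where X and skip have shape (N, S, C) and the (C)-shaped bias is broadcast over the batch and image dimensions, exactly as GetExpectedResult does with bias_data[i % bias_length]. Below is a minimal NumPy sketch of that reference computation; the function name and the concrete shapes are chosen for illustration only.

```python
import numpy as np


def bias_add_reference(x: np.ndarray, bias: np.ndarray, skip: np.ndarray) -> np.ndarray:
    """Reference semantics of com.microsoft.BiasAdd: Y = X + bias + skip.

    x and skip have shape (N, S, C); bias has shape (C,) and is broadcast over
    the batch (N) and image (S = H*W) dimensions, matching GetExpectedResult in
    bias_add_op_test.cc.
    """
    assert x.shape == skip.shape and bias.shape == (x.shape[-1],)
    return x + bias + skip  # NumPy broadcasting adds bias along the channel axis


# Smallest configuration used by the unit tests: batch_size=2, image_size=5, num_channels=320.
rng = np.random.default_rng(0)
x = rng.normal(0.0, 0.3, (2, 5, 320)).astype(np.float32)
bias = rng.normal(0.0, 0.3, (320,)).astype(np.float32)
skip = rng.normal(0.0, 0.3, (2, 5, 320)).astype(np.float32)
print(bias_add_reference(x, bias, skip).shape)  # (2, 5, 320)
```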
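Similarly, the new --force_fp32_ops option in optimize_pipeline.py accepts module:Operator entries (for example unet:Attention) and folds them into the op_block_list passed to convert_float_to_float16, so the listed operators keep running in float32 after fp16 conversion. The sketch below restates that parsing and block-list composition as a standalone function; the function name is illustrative, while the module names and the always-blocked RandomNormalLike entry are taken from the diff above.

```python
from typing import Dict, List


def parse_force_fp32_ops(force_fp32_ops: List[str]) -> Dict[str, List[str]]:
    """Group "module:Operator" entries by pipeline module, as optimize_sd_pipeline does."""
    force_fp32_operators: Dict[str, List[str]] = {
        "unet": [],
        "vae_encoder": [],
        "vae_decoder": [],
        "text_encoder": [],
        "safety_checker": [],
    }
    for entry in force_fp32_ops or []:
        parts = entry.split(":")
        # Operator names are case sensitive and start with an upper-case letter.
        if len(parts) == 2 and parts[0] in force_fp32_operators and parts[1] and parts[1][0].isupper():
            force_fp32_operators[parts[0]].append(parts[1])
        else:
            raise ValueError(f"expected module:Operator like unet:Attention, got {entry}")
    return force_fp32_operators


# Example: keep Attention in float32 for the unet model during fp16 conversion.
blocked = parse_force_fp32_ops(["unet:Attention"])
op_block_list = ["RandomNormalLike"] + blocked["unet"]
print(op_block_list)  # ['RandomNormalLike', 'Attention']
# optimize_sd_pipeline then passes this list to m.convert_float_to_float16(keep_io_types=False, op_block_list=...).
```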