Skip to content

Commit

Permalink
Stable Diffusion CUDA Optimizations Part 3 (#14646)
Browse files Browse the repository at this point in the history
The third part for stable diffusion CUDA optimizations
(1) Add BiasAdd operator to replace two Add (bias and residual); Add
fusion for BiasAdd
(2) Add Attention fusion for VAE decoder.
(3) Update float16 conversion to handle Resize and GroupNorm. This could
reduce two Cast nodes for each Resize op in fp16 model.
(4) Force inputs and outputs to be float16 to avoid data casts in the
pipeline.
(5) Add options --force_fp32_ops, --inspect etc in optimize script so that
user could force some operator to run in float32 to potentially get
better image quality (with cost of performance).

Performance tests show slight improvement in T4. Average latency reduced
0.1 seconds (from 5.35s to 5.25s) for 512x512 in 50 steps.
  • Loading branch information
tianleiwu authored Feb 14, 2023
1 parent 6eeeecf commit f638c5a
Show file tree
Hide file tree
Showing 28 changed files with 1,355 additions and 253 deletions.
35 changes: 35 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Do not modify directly.*
* <a href="#com.microsoft.Attention">com.microsoft.Attention</a>
* <a href="#com.microsoft.AttnLSTM">com.microsoft.AttnLSTM</a>
* <a href="#com.microsoft.BeamSearch">com.microsoft.BeamSearch</a>
* <a href="#com.microsoft.BiasAdd">com.microsoft.BiasAdd</a>
* <a href="#com.microsoft.BiasDropout">com.microsoft.BiasDropout</a>
* <a href="#com.microsoft.BiasGelu">com.microsoft.BiasGelu</a>
* <a href="#com.microsoft.BiasSoftmax">com.microsoft.BiasSoftmax</a>
Expand Down Expand Up @@ -468,6 +469,40 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.BiasAdd"></a><a name="com.microsoft.biasadd">**com.microsoft.BiasAdd**</a>

Add input with bias, then add residual inputs.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Inputs

<dl>
<dt><tt>X</tt> : T</dt>
<dd>Input tensor. Dimensions are (N, S, C), where N is the batch size, S is image size H*W, and C is number of channels</dd>
<dt><tt>bias</tt> : T</dt>
<dd>Bias tensor. Dimensions are (C)</dd>
<dt><tt>skip</tt> : T</dt>
<dd>Residual tensor. Dimensions are (N, S, C)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>The output tensor with dimensions (N, S, C)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


### <a name="com.microsoft.BiasDropout"></a><a name="com.microsoft.biasdropout">**com.microsoft.BiasDropout**</a>

output, dropout_mask = Dropout(data + bias, ratio) + residual, Intended to specialize the dropout pattern commonly found in transformer models.
Expand Down
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**<br> *in* bias:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
Expand Down
204 changes: 104 additions & 100 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

Large diffs are not rendered by default.

85 changes: 85 additions & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/diffusion/bias_add.h"
#include "contrib_ops/cuda/diffusion/bias_add_impl.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
BiasAdd, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
BiasAdd<T>);

REGISTER_KERNEL_TYPED(MLFloat16);
REGISTER_KERNEL_TYPED(float);

using namespace ONNX_NAMESPACE;

template <typename T>
BiasAdd<T>::BiasAdd(const OpKernelInfo& op_info) : CudaKernel(op_info) {
}

template <typename T>
Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
// Input: [batch_size, height*width, channels]
// Bias: [channels]
// Skip: [batch_size, height*width, channels]
// Output: [batch_size, height*width, channels]

const Tensor* input = context->Input<Tensor>(0);

const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The input is expected to have 3 dimensions, got ", input_dims.size());
}

if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels should be 320, 640 or 1280, got ", input_dims[2]);
}

const Tensor* bias = context->Input<Tensor>(1);
const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The bias is expected to have 1 dimensions, got ", bias_dims.size());
}
if (bias_dims[0] != input_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in the last dimension of input and bias are not the same");
}

const Tensor* skip = context->Input<Tensor>(2);
if (skip->Shape() != input->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Shape of input and skip (residual) shall be the same");
}

Tensor* output = context->Output(0, input->Shape());

typedef typename ToCudaType<T>::MappedType CudaT;
const int32_t grid_size = static_cast<int32_t>(input_dims[0] * input_dims[1]);
LaunchBiasAddKernel<CudaT>(Stream(context), grid_size, static_cast<int32_t>(input_dims[2]),
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(bias->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
reinterpret_cast<CudaT*>(output->MutableData<T>()));

CUDA_RETURN_IF_ERROR(cudaPeekAtLastError());
return Status::OK();
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

using namespace onnxruntime::cuda;

template <typename T>
class BiasAdd final : public CudaKernel {
public:
BiasAdd(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* context) const override;
};

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
79 changes: 79 additions & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// The CUDA kernel is modified from SeqLen2Spatial plugin of TensorRT 8.5.
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cub/cub.cuh>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/diffusion/bias_add_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T, int32_t C, int32_t TPB>
__global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, T* output) {
int32_t base_offset = blockIdx.x * C + threadIdx.x;
int32_t bias_offset = threadIdx.x;

#pragma unroll
for (int32_t i = 0; i < C / TPB; ++i) {
output[base_offset] = input[base_offset] + bias[bias_offset] + residual[base_offset];
base_offset += TPB;
bias_offset += TPB;
}
}

template __global__ void BiasAddKernel<float, 320, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<float, 640, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<float, 1280, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<half, 320, 320>(half const*, half const*, half const*, half*);
template __global__ void BiasAddKernel<half, 640, 320>(half const*, half const*, half const*, half*);
template __global__ void BiasAddKernel<half, 1280, 320>(half const*, half const*, half const*, half*);

template <typename T>
void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
T const* input, T const* bias, T const* residual, T* output) {
constexpr int32_t TPB = 320; // thread per block
switch (num_channels) {
case 320:
(BiasAddKernel<T, 320, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
break;
case 640:
(BiasAddKernel<T, 640, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
break;
case 1280:
(BiasAddKernel<T, 1280, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
break;
default:
ORT_NOT_IMPLEMENTED("Not implemented");
}
}

template void LaunchBiasAddKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
float const* input, float const* bias, float const* residual, float* output);

template void LaunchBiasAddKernel<half>(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
half const* input, half const* bias, half const* residual, half* output);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
19 changes: 19 additions & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/common/status.h"
#include <cuda.h>

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T>
void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
T const* input, T const* bias, T const* residual, T* output);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {
Expand All @@ -35,13 +37,9 @@ __global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) {

#pragma unroll
for (int32_t i = 0; i < HHS / TPB; ++i) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
auto value_left = (float)(input[index_input] + bias[index_bias]);
auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]);
#else
auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]);
auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]);
#endif

// Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0)
float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f);
float result = value_left * gelu_right;
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,32 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
updateOutputShape(ctx, 0, output_shape);
}
}));

constexpr const char* BiasAdd_ver1_doc = R"DOC(
Add input with bias, then add residual inputs.
)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(
BiasAdd, 1,
OpSchema()
.SetDoc(BiasAdd_ver1_doc)
.Input(0,
"X",
"Input tensor. Dimensions are (N, S, C), where N is the batch size, S is image size H*W, and C is number of channels",
"T")
.Input(1,
"bias",
"Bias tensor. Dimensions are (C)",
"T")
.Input(2,
"skip",
"Residual tensor. Dimensions are (N, S, C)",
"T")
.Output(0,
"Y",
"The output tensor with dimensions (N, S, C)",
"T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
} // namespace contrib
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasAdd);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist);
Expand Down Expand Up @@ -139,6 +140,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasAdd)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist)>());
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
"GroupNorm": self._infer_GroupNorm,
"BiasSplitGelu": self._infer_BiasSplitGelu,
"BiasAdd": self._infer_BiasAdd,
"NhwcConv": self._infer_NhwcConv,
}
self.aten_op_dispatcher_ = {
Expand Down Expand Up @@ -443,6 +444,7 @@ def _onnx_infer_single_node(self, node):
"MultiHeadAttention",
"GroupNorm",
"BiasSplitGelu",
"BiasAdd",
"NhwcConv",
]

Expand Down Expand Up @@ -2104,6 +2106,9 @@ def _infer_BiasSplitGelu(self, node):
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))

def _infer_BiasAdd(self, node):
self._propagate_shape_and_type(node)

def _infer_PythonOp(self, node):
output_tensor_types = get_attribute(node, "output_tensor_types")
assert output_tensor_types
Expand Down
Loading

0 comments on commit f638c5a

Please sign in to comment.