Stable Diffusion CUDA Optimizations Part 3 (#14646)
The third part of the stable diffusion CUDA optimizations:
(1) Add a BiasAdd operator to replace two Add nodes (bias and residual), and add a graph fusion for BiasAdd.
(2) Add Attention fusion for the VAE decoder.
(3) Update float16 conversion to handle Resize and GroupNorm. This can remove two Cast nodes for each Resize op in the fp16 model.
(4) Force inputs and outputs to be float16 to avoid data casts in the pipeline.
(5) Add options --force_fp32_ops, --inspect, etc. to the optimize script so that users can force selected operators to run in float32, potentially improving image quality at the cost of performance.
Performance tests show a slight improvement on T4: average latency is reduced by 0.1 seconds (from 5.35s to 5.25s) for 512x512 images with 50 steps.
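To illustrate item (1), here is a minimal CPU reference of the fused semantics (an illustration, not part of the commit): BiasAdd computes output = input + bias + skip in a single elementwise pass, replacing the separate bias Add and residual Add, matching the CUDA kernel in bias_add_impl.cu below.

// Reference (CPU) semantics of the fused BiasAdd operator, for illustration only.
// input/skip/output: [batch, height*width, channels] flattened row-major; bias: [channels].
#include <cstdint>
#include <vector>

std::vector<float> BiasAddReference(const std::vector<float>& input,
                                    const std::vector<float>& bias,
                                    const std::vector<float>& skip,
                                    int64_t rows, int64_t channels) {
  std::vector<float> output(input.size());
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < channels; ++c) {
      const int64_t i = r * channels + c;
      // One fused pass instead of Add(input, bias) followed by Add(..., skip).
      output[i] = input[i] + bias[c] + skip[i];
    }
  }
  return output;
}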
Showing 28 changed files with 1,355 additions and 253 deletions.
onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc (new file, 85 lines):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/diffusion/bias_add.h"
#include "contrib_ops/cuda/diffusion/bias_add_impl.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      BiasAdd,                                                    \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      BiasAdd<T>);

REGISTER_KERNEL_TYPED(MLFloat16);
REGISTER_KERNEL_TYPED(float);

using namespace ONNX_NAMESPACE;

template <typename T>
BiasAdd<T>::BiasAdd(const OpKernelInfo& op_info) : CudaKernel(op_info) {
}

template <typename T>
Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
  // Input:  [batch_size, height*width, channels]
  // Bias:   [channels]
  // Skip:   [batch_size, height*width, channels]
  // Output: [batch_size, height*width, channels]

  const Tensor* input = context->Input<Tensor>(0);

  const auto& input_dims = input->Shape().GetDims();
  if (input_dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "The input is expected to have 3 dimensions, got ", input_dims.size());
  }

  if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels should be 320, 640 or 1280, got ", input_dims[2]);
  }

  const Tensor* bias = context->Input<Tensor>(1);
  const auto& bias_dims = bias->Shape().GetDims();
  if (bias_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "The bias is expected to have 1 dimension, got ", bias_dims.size());
  }
  if (bias_dims[0] != input_dims[2]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Number of channels in the last dimension of input and bias are not the same");
  }

  const Tensor* skip = context->Input<Tensor>(2);
  if (skip->Shape() != input->Shape()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Shape of input and skip (residual) shall be the same");
  }

  Tensor* output = context->Output(0, input->Shape());

  typedef typename ToCudaType<T>::MappedType CudaT;

  // One thread block per (batch, spatial) row; the kernel adds bias and skip across channels.
  const int32_t grid_size = static_cast<int32_t>(input_dims[0] * input_dims[1]);
  LaunchBiasAddKernel<CudaT>(Stream(context), grid_size, static_cast<int32_t>(input_dims[2]),
                             reinterpret_cast<const CudaT*>(input->Data<T>()),
                             reinterpret_cast<const CudaT*>(bias->Data<T>()),
                             reinterpret_cast<const CudaT*>(skip->Data<T>()),
                             reinterpret_cast<CudaT*>(output->MutableData<T>()));

  CUDA_RETURN_IF_ERROR(cudaPeekAtLastError());
  return Status::OK();
}

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
onnxruntime/contrib_ops/cuda/diffusion/bias_add.h (new file, 23 lines):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

using namespace onnxruntime::cuda;

template <typename T>
class BiasAdd final : public CudaKernel {
 public:
  BiasAdd(const OpKernelInfo& op_kernel_info);
  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu (new file, 79 lines):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// The CUDA kernel is modified from the SeqLen2Spatial plugin of TensorRT 8.5.
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cub/cub.cuh>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/diffusion/bias_add_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T, int32_t C, int32_t TPB>
__global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, T* output) {
  // Each block handles one row of C channels; each of the TPB threads
  // covers C / TPB channels with stride TPB. The bias is broadcast per row.
  int32_t base_offset = blockIdx.x * C + threadIdx.x;
  int32_t bias_offset = threadIdx.x;

#pragma unroll
  for (int32_t i = 0; i < C / TPB; ++i) {
    output[base_offset] = input[base_offset] + bias[bias_offset] + residual[base_offset];
    base_offset += TPB;
    bias_offset += TPB;
  }
}

template __global__ void BiasAddKernel<float, 320, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<float, 640, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<float, 1280, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<half, 320, 320>(half const*, half const*, half const*, half*);
template __global__ void BiasAddKernel<half, 640, 320>(half const*, half const*, half const*, half*);
template __global__ void BiasAddKernel<half, 1280, 320>(half const*, half const*, half const*, half*);

template <typename T>
void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
                         T const* input, T const* bias, T const* residual, T* output) {
  constexpr int32_t TPB = 320;  // threads per block
  switch (num_channels) {
    case 320:
      (BiasAddKernel<T, 320, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
      break;
    case 640:
      (BiasAddKernel<T, 640, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
      break;
    case 1280:
      (BiasAddKernel<T, 1280, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, residual, output);
      break;
    default:
      ORT_NOT_IMPLEMENTED("Not implemented");
  }
}

template void LaunchBiasAddKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
                                         float const* input, float const* bias, float const* residual, float* output);

template void LaunchBiasAddKernel<half>(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
                                        half const* input, half const* bias, half const* residual, half* output);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
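For reference, here is a minimal standalone harness sketching how LaunchBiasAddKernel could be exercised outside the operator. This is a hedged sketch, not part of the commit: it assumes the onnxruntime headers are on the include path, links against the CUDA runtime, and uses the default stream.

// Hypothetical check for LaunchBiasAddKernel<float> with 320 channels (illustration only).
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include "contrib_ops/cuda/diffusion/bias_add_impl.h"

int main() {
  const int32_t batch = 2, hw = 64 * 64;
  const int32_t grid_size = batch * hw;   // one block per (batch, spatial) row
  const int32_t channels = 320;           // must be 320, 640 or 1280
  const size_t count = static_cast<size_t>(grid_size) * channels;

  std::vector<float> input(count, 1.0f), bias(channels, 0.5f), skip(count, 2.0f), output(count);

  float *d_in, *d_bias, *d_skip, *d_out;
  cudaMalloc(&d_in, count * sizeof(float));
  cudaMalloc(&d_bias, channels * sizeof(float));
  cudaMalloc(&d_skip, count * sizeof(float));
  cudaMalloc(&d_out, count * sizeof(float));
  cudaMemcpy(d_in, input.data(), count * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_bias, bias.data(), channels * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_skip, skip.data(), count * sizeof(float), cudaMemcpyHostToDevice);

  // nullptr = default stream; the kernel writes input + bias + skip per element.
  onnxruntime::contrib::cuda::LaunchBiasAddKernel<float>(nullptr, grid_size, channels,
                                                         d_in, d_bias, d_skip, d_out);
  cudaDeviceSynchronize();

  cudaMemcpy(output.data(), d_out, count * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("output[0] = %f (expect 3.5)\n", output[0]);  // 1.0 + 0.5 + 2.0

  cudaFree(d_in); cudaFree(d_bias); cudaFree(d_skip); cudaFree(d_out);
  return 0;
}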
onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.h (new file, 19 lines):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/common/status.h"
#include <cuda.h>

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T>
void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
                         T const* input, T const* bias, T const* residual, T* output);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
(Diffs for the remaining changed files are not rendered.)