[ROCm] Support for gpt2-based model inferencing (#14675)
When running inference with a real GPT-2-based model, we found some gaps between the CUDA and
ROCm codebases.

The fixes include:

1. A minimal code change to fix tensor shapes in the Attention op
2. Support for the optional output tensor (the input + skip + bias sum) in SkipLayerNorm; a host-side reference sketch follows the commit metadata below
3. A fix for a build error found on MI200

---------

Co-authored-by: Ubuntu <ettao@ettao-amd-dev1.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
ytaous and Ubuntu authored Feb 15, 2023
1 parent a216c9a commit d49cea0
Showing 9 changed files with 175 additions and 95 deletions.
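
For reference, the SkipLayerNorm change (item 2 above) adds an optional fourth output that returns the element-wise sum of input, skip, and the optional bias, computed before normalization. The host-side C++ sketch below shows that semantics; it is an illustration only (the function name, signature, and row-major layout are our assumptions), not the HIP implementation modified in the diffs that follow.

// Host-side reference sketch of SkipLayerNorm with the new optional output.
// Illustration only: names and layout are assumptions, not the HIP kernel.
#include <cmath>
#include <cstdio>
#include <vector>

void SkipLayerNormRef(const std::vector<float>& input,   // [rows, ld]
                      const std::vector<float>& skip,    // [rows, ld]
                      const std::vector<float>& gamma,   // [ld]
                      const std::vector<float>& beta,    // [ld]
                      const std::vector<float>& bias,    // [ld] or empty
                      float epsilon, int ld,
                      std::vector<float>& output,                        // [rows, ld]
                      std::vector<float>* skip_input_bias_add_output) {  // optional, may be nullptr
  const int rows = static_cast<int>(input.size()) / ld;
  for (int r = 0; r < rows; ++r) {
    // 1) Element-wise sum of input, skip, and (optional) bias.
    std::vector<float> sum(ld);
    for (int i = 0; i < ld; ++i) {
      const int idx = r * ld + i;
      sum[i] = input[idx] + skip[idx] + (bias.empty() ? 0.0f : bias[i]);
      if (skip_input_bias_add_output != nullptr) {
        (*skip_input_bias_add_output)[idx] = sum[i];  // the new optional output
      }
    }
    // 2) Layer normalization over the summed row.
    float mean = 0.0f, var = 0.0f;
    for (int i = 0; i < ld; ++i) mean += sum[i];
    mean /= ld;
    for (int i = 0; i < ld; ++i) var += (sum[i] - mean) * (sum[i] - mean);
    var /= ld;
    const float inv_std = 1.0f / std::sqrt(var + epsilon);
    for (int i = 0; i < ld; ++i) {
      output[r * ld + i] = (sum[i] - mean) * inv_std * gamma[i] + beta[i];
    }
  }
}

int main() {
  const int ld = 4;
  std::vector<float> input{1, 2, 3, 4}, skip{0.5f, 0.5f, 0.5f, 0.5f};
  std::vector<float> gamma(ld, 1.0f), beta(ld, 0.0f), bias;  // no bias
  std::vector<float> output(ld), sum_out(ld);
  SkipLayerNormRef(input, skip, gamma, beta, bias, 1e-5f, ld, output, &sum_out);
  for (int i = 0; i < ld; ++i) std::printf("%g/%g ", sum_out[i], output[i]);
  std::printf("\n");
  return 0;
}
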
76 changes: 40 additions & 36 deletions onnxruntime/contrib_ops/rocm/bert/attention.cc
@@ -15,15 +15,21 @@ namespace onnxruntime {
namespace contrib {
namespace rocm {

constexpr int kPastSequenceLengthInputIndex = 6;
constexpr int kPastInputIndex = 4;
constexpr int kPresentOutputIndex = 1;

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Attention, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
(*KernelDefBuilder::Create()) \
.MayInplace(kPastInputIndex, kPresentOutputIndex) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
Attention<T>);

REGISTER_KERNEL_TYPED(float)
@@ -40,47 +46,41 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(4);
const Tensor* relative_position_bias = context->Input<Tensor>(5);
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);

auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
mask_index,
past,
relative_position_bias,
nullptr,
device_prop.maxThreadsPerBlock));

// input shape (batch_size, sequence_length, input_hidden_size)
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int input_hidden_size = static_cast<int>(shape[2]);

// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP
// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;

int head_size = hidden_size / num_heads_;
&parameters,
device_prop.maxThreadsPerBlock,
past_seq_len));
ORT_ENFORCE(parameters.sequence_length == parameters.kv_sequence_length); // self attention

TensorShapeVector output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
Tensor* output = context->Output(0, output_shape);

int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
std::vector<int64_t> present_dims{
2, parameters.batch_size, parameters.num_heads,
parameters.past_present_share_buffer ? parameters.max_sequence_length : parameters.total_sequence_length,
parameters.head_size};
TensorShape present_shape(present_dims);
Tensor* present = context->Output(kPresentOutputIndex, present_shape);

rocblas_handle rocblas = GetRocblasHandle(context);
constexpr size_t element_size = sizeof(T);

// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size, context->GetComputeStream());
int m = parameters.batch_size * parameters.sequence_length;
int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size);
int k = parameters.input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(static_cast<size_t>(m) * n, context->GetComputeStream());

typedef typename ToHipType<T>::MappedType HipT;
namespace blas = rocm::tunable::blas;
@@ -108,8 +108,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
/*beta=*/1.0f,
reinterpret_cast<HipT*>(gemm_buffer.get()), n));

size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size,
sequence_length, past_sequence_length);
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
parameters.head_size,
parameters.sequence_length,
parameters.past_sequence_length);

auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
return LaunchAttentionKernel(
@@ -118,16 +122,16 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
Stream(context),
rocblas,
element_size,
batch_size,
sequence_length,
num_heads_,
head_size,
past_sequence_length,
is_unidirectional_,
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.head_size,
parameters.past_sequence_length,
parameters.is_unidirectional,
reinterpret_cast<const void*>(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->Data<int>(),
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
mask_filter_value_,
parameters.mask_filter_value,
nullptr == past ? nullptr : past->Data<T>(),
nullptr == relative_position_bias ? nullptr : relative_position_bias->Data<T>(),
work_space.get(),
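The reworked attention.cc above now derives all shapes from the AttentionParameters filled in by CheckInputs instead of recomputing them locally. As a rough illustration of that shape math (not actual EP code; the struct below is a hypothetical subset of AttentionParameters with made-up GPT-2-small-like values), the present key/value cache shape and the fused QKV GEMM dimensions work out as follows:

// Illustration only: hypothetical subset of AttentionParameters with example values.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Params {
  int batch_size = 1;
  int sequence_length = 8;          // new tokens in this step
  int past_sequence_length = 16;    // cached tokens
  int total_sequence_length = 24;   // past + current
  int max_sequence_length = 1024;   // used only when past/present share a buffer
  bool past_present_share_buffer = false;
  int num_heads = 12;
  int head_size = 64;
  int hidden_size = 768;            // q/k hidden size
  int v_hidden_size = 768;          // equals hidden_size for GPT-2 self-attention
  int input_hidden_size = 768;
};

int main() {
  Params p;

  // Present shape as built above: {2, B, num_heads, share ? max_seq : total_seq, head_size}.
  std::vector<int64_t> present_dims{
      2, p.batch_size, p.num_heads,
      p.past_present_share_buffer ? p.max_sequence_length : p.total_sequence_length,
      p.head_size};
  std::printf("present: [%lld, %lld, %lld, %lld, %lld]\n",
              (long long)present_dims[0], (long long)present_dims[1],
              (long long)present_dims[2], (long long)present_dims[3],
              (long long)present_dims[4]);

  // Fused QKV GEMM dimensions as computed above.
  const int m = p.batch_size * p.sequence_length;                 // output rows
  const int n = p.hidden_size + p.hidden_size + p.v_hidden_size;  // Q + K + V columns
  const int k = p.input_hidden_size;                              // reduction dimension
  std::printf("QKV GEMM: m=%d n=%d k=%d, scratch elements=%zu\n",
              m, n, k, static_cast<size_t>(m) * n);
  return 0;
}
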
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc
@@ -43,6 +43,10 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {

Tensor* output = ctx->Output(0, input->Shape());

// For inference, we support one more optional output, which is the element-wise
// sum of the input, skip, and (optional) bias tensors before normalization
Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());

if (input->Shape() != skip->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip is expected to have same shape as input");
@@ -101,6 +105,7 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
GetTuningContext(),
Stream(ctx),
reinterpret_cast<HipT*>(output->MutableData<T>()),
skip_input_bias_add_output != nullptr ? reinterpret_cast<HipT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
reinterpret_cast<const HipT*>(input->Data<T>()),
reinterpret_cast<const HipT*>(skip->Data<T>()),
reinterpret_cast<const HipT*>(gamma->Data<T>()),
10 changes: 5 additions & 5 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu
@@ -41,12 +41,12 @@ namespace rocm {

template <typename T>
Status LaunchSkipLayerNormKernel(
RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, int ld, int element_count) {
RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input,
const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) {
// this must be true because element_count is the total size of the tensor
assert(element_count % ld == 0);

SkipLayerNormParams<T> params(tuning_ctx, stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count);
SkipLayerNormParams<T> params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, gamma, beta, bias, epsilon, ld, element_count);

if (tuning_ctx->IsTunableOpEnabled()) {
static SkipLayerNormTunableOp<T> op;
@@ -57,13 +57,13 @@ Status LaunchSkipLayerNormKernel(
}

template Status LaunchSkipLayerNormKernel<float>(
RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, const float* input,
RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, float* skip_input_bias_add_output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, int ld,
int element_count);

template Status LaunchSkipLayerNormKernel<half>(
RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, const half* input,
RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, half* skip_input_bias_add_output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, int ld,
int element_count);
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h
@@ -15,6 +15,7 @@ Status LaunchSkipLayerNormKernel(
RocmTuningContext* tuning,
hipStream_t stream,
T* output, // output tensor
T* skip_input_bias_add_output, // optional output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
37 changes: 32 additions & 5 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h
@@ -26,7 +26,7 @@ half maybe2half(float x) {
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output) {
const T epsilon, T* output, T* skip_input_bias_add_output) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;

@@ -39,6 +39,11 @@ __global__ void SkipLayerNormKernel(
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));

if (skip_input_bias_add_output != nullptr) {
skip_input_bias_add_output[idx] = val;
}

output[idx] = val;
}

@@ -49,7 +54,8 @@ __global__ void SkipLayerNormKernel(
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelVec(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
bool hasBias, bool hasSkipInputBiasAdditionOutput) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;

Expand All @@ -58,7 +64,7 @@ __global__ void SkipLayerNormKernelVec(
hipcub::KeyValuePair<T, T> thread_data(0, 0);

using VecT = aligned_vector<T, ILP>;
T input_v[ILP], skip_v[ILP], bias_v[ILP];
T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];
if (threadIdx.x * ILP < ld) {
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
@@ -76,9 +82,19 @@ __global__ void SkipLayerNormKernelVec(
#pragma unroll
for (int k = 0; k < ILP; k++) {
input_v[k] += hasBias ? skip_v[k] + bias_v[k] : skip_v[k];

if (hasSkipInputBiasAdditionOutput) {
skip_input_bias_add_output_v[k] = input_v[k];
}

const T rldval = reverse_ld * input_v[k];
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * input_v[k]));
}

if (hasSkipInputBiasAdditionOutput) {
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
}

*(reinterpret_cast<VecT*>(&output[idx])) = *reinterpret_cast<VecT*>(&input_v[0]);
}
}
@@ -90,12 +106,13 @@ __global__ void SkipLayerNormKernelVec(
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
bool hasBias, bool hasSkipInputBiasAdditionOutput) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld

using VecT = aligned_vector<T, ILP>;
T input_v[ILP], skip_v[ILP], bias_v[ILP];
T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];

hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));

@@ -116,10 +133,20 @@ __global__ void SkipLayerNormKernelSmall(
#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];

if (hasSkipInputBiasAdditionOutput) {
skip_input_bias_add_output_v[i] = input_v[i];
}

const T rldval = rld * input_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
}

if (hasSkipInputBiasAdditionOutput) {
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
}

thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
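The vectorized kernels above store ILP elements at a time by reinterpreting a register array as an aligned vector type. The small host-side sketch below mimics that store pattern for the new skip_input_bias_add_output writes; aligned_vector here is a toy stand-in for the device-side helper, so treat it as an illustration only.

// Host-side illustration of the ILP-wide vector store pattern used above.
#include <cstdio>

template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) aligned_vector {  // toy stand-in for the device helper
  T val[ILP];
};

int main() {
  constexpr int ILP = 4;
  using VecT = aligned_vector<float, ILP>;

  alignas(sizeof(float) * ILP) float skip_input_bias_add_output[8] = {};
  alignas(sizeof(float) * ILP) float sum_v[ILP] = {1.f, 2.f, 3.f, 4.f};  // per-thread registers

  const int idx = 4;  // element offset owned by this "thread"; ILP-aligned
  // Single vector-wide store, mirroring the kernel's
  //   *(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = ...;
  *reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx]) =
      *reinterpret_cast<VecT*>(&sum_v);

  for (float v : skip_input_bias_add_output) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}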