Support Smooth Softmax in GroupQueryAttention #21867

Merged · 9 commits · Aug 27, 2024
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -2541,6 +2541,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>smooth_softmax</tt> : int</dt>
<dd>Use a smooth factor in softmax.</dd>
</dl>

#### Inputs (7 - 9)
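For reference, both the CPU helper and the CUDA flash kernels changed below implement the smooth variant by adding an implicit zero logit to the softmax denominator; a sketch of the math (with the running max subtracted for numerical stability, as the kernels do):

```math
y_i = \frac{e^{x_i}}{1 + \sum_{j=1}^{D} e^{x_j}}
    = \frac{e^{x_i - m}}{e^{-m} + \sum_{j=1}^{D} e^{x_j - m}},
\qquad m = \max\left(0,\ \max_j x_j\right)
```

Because of the extra term, each attention row sums to less than one, so a query can effectively attend to nothing when all of its scores are very negative.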
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -116,6 +116,7 @@ struct GroupQueryAttentionParameters {
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
bool use_smooth_softmax;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
43 changes: 42 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -16,6 +16,47 @@
namespace onnxruntime {
namespace contrib {

template <typename T>
void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t j = begin; j != end; ++j) {
float* x = reinterpret_cast<T*>(score) + j * D;
float* y = x;

float max = -std::numeric_limits<float>::infinity();
for (int i = 0; i < D; i++) {
if (max < x[i])
max = x[i];
}

if (max < 0.0f) {
max = 0.0f;
}

for (int i = 0; i < D; i++) {
y[i] = expf(x[i] - max);
}

double sum = 0.0;

for (int i = 0; i < D; i++) {
sum += x[i];
}

sum += exp(static_cast<double>(-max));

for (int i = 0; i < D; i++) {
y[i] = x[i] / static_cast<float>(sum);
}
}
});
}

template <>
inline void ComputeSmoothSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) {
MlasComputeSoftmax(score, score, N, D, false, true, tp);
}

template <typename T>
void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
@@ -58,7 +99,7 @@

template <>
inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) {
MlasComputeSoftmax(score, score, N, D, false, tp);
MlasComputeSoftmax(score, score, N, D, false, false, tp);
}

template <typename T>
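A minimal standalone sketch of what the generic `ComputeSmoothSoftmaxInplace` template above computes for one row; the function name `SmoothSoftmaxRow` and the driver are illustrative only, not part of the ORT API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Illustrative single-row equivalent of ComputeSmoothSoftmaxInplace:
// softmax over {x_0, ..., x_{D-1}, 0}, with the probability assigned to the
// implicit zero logit dropped from the output.
void SmoothSoftmaxRow(float* x, int D) {
  float max = 0.0f;  // the implicit zero logit keeps the running max at >= 0
  for (int i = 0; i < D; i++) {
    max = std::max(max, x[i]);
  }
  double sum = std::exp(static_cast<double>(-max));  // contribution of the zero logit
  for (int i = 0; i < D; i++) {
    x[i] = std::exp(x[i] - max);
    sum += x[i];
  }
  for (int i = 0; i < D; i++) {
    x[i] = static_cast<float>(x[i] / sum);
  }
}

int main() {
  float row[4] = {-10.0f, -10.0f, -10.0f, -10.0f};
  SmoothSoftmaxRow(row, 4);
  // All scores are far below the implicit zero logit, so the probabilities
  // collapse toward zero instead of the uniform 0.25 a standard softmax gives.
  for (float v : row) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}
```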
19 changes: 16 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -30,6 +30,8 @@ class GQAAttentionBase {
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;

use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;

local_window_size_ = has_local ? static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) : -1;
}

@@ -40,6 +42,8 @@ class GQAAttentionBase {
bool rotary_interleaved_;
int local_window_size_;

bool use_smooth_softmax_;

template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxN_kvxSxH
@@ -195,10 +199,19 @@
for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
}
ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
local_window_size_ + 1, nullptr);
if (use_smooth_softmax_) {
ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
local_window_size_ + 1, nullptr);
} else {
ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
local_window_size_ + 1, nullptr);
}
} else {
ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
if (use_smooth_softmax_) {
ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
} else {
ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
}
}

// set causal [seq_causal_length, total_seqlen) to 0.f
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -303,7 +303,7 @@ Status FlashAttention(
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16,
parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, false,
parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));

2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -121,6 +121,8 @@ struct Flash_fwd_params : public Qkv_params {

bool is_rotary_interleaved = false;

bool smooth_softmax = false;

int num_splits = 0; // For split-KV version

void* __restrict__ alibi_slopes_ptr = nullptr;
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -37,6 +37,7 @@ void set_params_fprop(Flash_fwd_params& params,
float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
bool kv_bsnh = true,
int window_size_left = -1,
int window_size_right = -1) {
@@ -47,6 +48,7 @@
params.o_ptr = out;

params.is_bf16 = is_bf16;
params.smooth_softmax = use_smooth_softmax;

// All stride are in elements, not bytes.
if (kv_bsnh) {
@@ -267,6 +269,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
@@ -293,6 +296,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
is_bf16,
use_smooth_softmax,
kv_bsnh,
local_window_size,
is_causal ? 0 : -1);
@@ -365,6 +369,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
is_bf16,
false,
true,
-1,
is_causal ? 0 : -1);
@@ -424,6 +429,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
const float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
bool past_bsnh, // otherwise bnsh
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
@@ -456,6 +462,7 @@
softmax_scale,
is_causal,
is_bf16,
use_smooth_softmax,
past_bsnh,
local_window_size,
is_causal ? 0 : -1);
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -52,6 +52,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
@@ -105,6 +106,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
const float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
bool past_bsnh, // otherwise bnsh
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
@@ -346,7 +346,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax);
Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax);

// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
@@ -902,7 +902,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);
Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax, params.smooth_softmax);

Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h
@@ -159,15 +159,15 @@
};

template <bool Split = false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale) {
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = inv_sum;
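The new `expf(-row_max(mi) * softmax_scale)` term is the same implicit zero logit as in the CPU helper, expressed in the flash-attention running-softmax state (assuming, as in the upstream kernels this file derives from, that `row_max` holds the unscaled running max of the scores and `row_sum` accumulates the scaled, max-shifted exponentials):

```math
\text{row\_sum} = \sum_j e^{(s_j - \text{row\_max})\cdot\text{scale}}
\quad\Longrightarrow\quad
e^{(0 - \text{row\_max})\cdot\text{scale}} = \texttt{expf(-row\_max * softmax\_scale)}
```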
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -51,6 +51,7 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;

kernel_options_ = this->GetAttentionKernelOptions();

@@ -98,6 +99,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
// parameters.left_padding = left_padding_;
@@ -151,6 +153,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_smooth_softmax_ &&
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -28,6 +28,7 @@ class GroupQueryAttention final : public CudaKernel {
bool is_past_bsnh_;
bool do_rotary_;
bool rotary_interleaved_;
bool use_smooth_softmax_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -678,9 +678,9 @@
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
parameters.is_packed_qkv));
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));

// if (parameters.left_padding && parameters.is_prompt) {
// ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1073,6 +1073,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Rotate using interleaved pattern. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("smooth_softmax",
"Use a smooth factor in softmax.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape"
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1013,6 +1013,7 @@ MlasComputeSoftmax(
size_t N,
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
MLAS_THREADPOOL* ThreadPool
);

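A hedged usage sketch of the extended MLAS entry point, mirroring the in-place call made by `ComputeSmoothSoftmaxInplace<float>` in `attention_helper.h` above. The leading input/output pointer parameters are not shown in this hunk, so their order here is taken from that call site; the buffer sizes and the include path are placeholders:

```cpp
#include <vector>
#include "core/mlas/inc/mlas.h"  // MlasComputeSoftmax (ORT-internal header)

void RunSmoothSoftmax() {
  const size_t N = 8;    // rows (e.g. one per query position)
  const size_t D = 128;  // row length (e.g. total sequence length)
  std::vector<float> scores(N * D, 0.0f);  // scaled attention scores, filled elsewhere
  // In-place smooth softmax: LogSoftmax=false, SmoothSoftmax=true, no thread pool.
  MlasComputeSoftmax(scores.data(), scores.data(), N, D,
                     /*LogSoftmax=*/false,
                     /*SmoothSoftmax=*/true,
                     /*ThreadPool=*/nullptr);
}
```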