Support Smooth Softmax in fmha #21885

Merged · 2 commits · Aug 28, 2024
59 changes: 55 additions & 4 deletions cmake/patches/cutlass/cutlass_3.5.0.patch
@@ -1,13 +1,64 @@
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..34327633 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -221,6 +221,8 @@ struct AttentionKernel {
int32_t num_batches = 0;
int32_t num_heads = 0;

+ bool use_smooth_softmax = false;
+
// dropout
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
@@ -897,7 +899,8 @@ struct AttentionKernel {
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
- kSupportsBias ? 1.0f : p.scale);
+ kSupportsBias ? 1.0f : p.scale,
+ p.use_smooth_softmax);

// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
- float scaling) {
+ float scaling,
+ bool use_smooth_softmax) {
/* Iterates on the accumulator and corresponding position on result matrix

(1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
- [&](int accum_m) { mi_row = mi[accum_m]; },
+ [&](int accum_m) { mi_row = mi[accum_m];},
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
- s_prime[id] = total_row;
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
}
}

diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"

#include <cuda_runtime.h>
+#include <cuda_fp16.h>

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
@@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
return reinterpret_cast<half_t const &>(result);
+#else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif
#endif
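
For context on the patched s_prime update above: smooth softmax normalizes each row of scores as exp(s_i) / (1 + sum_j exp(s_j)), i.e. an implicit extra key with logit 0 joins the denominator, so the attention weights may sum to less than one. Because the kernel accumulates in base 2 after subtracting the row maximum mi, that extra term becomes exp2f(-mi[id]), and the max_col <= kKeysPerBlock guard adds it only once, on the final key block. A minimal NumPy sketch of the same arithmetic (an illustration, not code from this PR):

import numpy as np

def standard_softmax(scores):
    m = scores.max(axis=-1, keepdims=True)      # row max, the role mi plays in the kernel
    e = np.exp(scores - m)
    return e / e.sum(axis=-1, keepdims=True)    # row sum, the role s_prime plays

def smooth_softmax(scores):
    # The implicit zero logit contributes exp(0 - m) to the row sum,
    # mirroring total_row + exp2f(-mi[id]) in the base-2 kernel math.
    m = scores.max(axis=-1, keepdims=True)
    e = np.exp(scores - m)
    return e / (e.sum(axis=-1, keepdims=True) + np.exp(-m))

scores = np.array([[-4.0, -5.0, -6.0]])
print(standard_softmax(scores).sum())   # 1.0
print(smooth_softmax(scores).sum())     # well below 1.0 when every score is small
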
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -415,6 +415,7 @@ Status EfficientAttention(
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
p.scale = scale;
p.use_smooth_softmax = false;

if (nullptr == data.mask_index) {
p.seqlen_k_ptr = nullptr;
@@ -220,6 +220,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.bias_strideM = 0;
p.bias_strideB = 0;
}

p.use_smooth_softmax = params.use_smooth_softmax;
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;
@@ -25,6 +25,7 @@ struct MemoryEfficientAttentionParams {
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
bool use_smooth_softmax;

float scale;

1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -153,7 +153,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_smooth_softmax_ &&
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
@@ -678,8 +678,8 @@ Status FlashAttention(
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));

// if (parameters.left_padding && parameters.is_prompt) {
@@ -843,6 +843,7 @@ Status EfficientAttention(
: nullptr;
p.stream = stream;
p.has_custom_right_padding = true;
p.use_smooth_softmax = parameters.use_smooth_softmax;
run_memory_efficient_attention(p);

DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass(
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
p.use_smooth_softmax = false;
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
p.seqlen_k_ptr = nullptr;
@@ -693,6 +693,7 @@ Status FusedAttentionCutlass(
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
p.use_smooth_softmax = false;
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
p.seqlen_k_ptr = nullptr;
4 changes: 2 additions & 2 deletions onnxruntime/test/python/transformers/test_flash_attn_cuda.py
@@ -2219,7 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)

@parameterized.expand(gqa_no_past_flash_attention_test_cases())
@@ -2263,7 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,
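
The two test changes above flip use_smooth_softmax to True for the memory-efficient GQA parity tests, which previously ran with it disabled. A reference that such a parity check could be compared against might look like the sketch below; the function name, single-head shapes, and absence of masking or KV cache are illustrative assumptions, not the test harness's actual helpers:

import numpy as np

def smooth_softmax_attention_ref(q, k, v, scale):
    # q: (seq_q, d); k, v: (seq_k, d). One head, no mask, no KV cache.
    scores = scale * (q @ k.T)                            # (seq_q, seq_k)
    m = scores.max(axis=-1, keepdims=True)
    e = np.exp(scores - m)
    denom = e.sum(axis=-1, keepdims=True) + np.exp(-m)    # the extra "+1", shifted by the row max
    return (e / denom) @ v
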