Skip to content

Commit

Permalink
smooth softmax for fmha
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Aug 28, 2024
1 parent 23f3912 commit 0dc380c
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 9 deletions.
59 changes: 55 additions & 4 deletions cmake/patches/cutlass/cutlass_3.5.0.patch
Original file line number Diff line number Diff line change
@@ -1,13 +1,64 @@
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..34327633 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -221,6 +221,8 @@ struct AttentionKernel {
int32_t num_batches = 0;
int32_t num_heads = 0;

+ bool use_smooth_softmax = false;
+
// dropout
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
@@ -897,7 +899,8 @@ struct AttentionKernel {
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
- kSupportsBias ? 1.0f : p.scale);
+ kSupportsBias ? 1.0f : p.scale,
+ p.use_smooth_softmax);

// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
- float scaling) {
+ float scaling,
+ bool use_smooth_softmax) {
/* Iterates on the accumulator and corresponding position on result matrix

(1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
- [&](int accum_m) { mi_row = mi[accum_m]; },
+ [&](int accum_m) { mi_row = mi[accum_m];},
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
- s_prime[id] = total_row;
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
}
}

diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"

#include <cuda_runtime.h>
+#include <cuda_fp16.h>

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
Expand All @@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
return reinterpret_cast<half_t const &>(result);
+#else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.bias_strideM = 0;
p.bias_strideB = 0;
}

p.use_smooth_softmax = params.use_smooth_softmax;
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct MemoryEfficientAttentionParams {
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
bool use_smooth_softmax;

float scale;

Expand Down
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_smooth_softmax_ &&
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,8 @@ Status FlashAttention(
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));

// if (parameters.left_padding && parameters.is_prompt) {
Expand Down Expand Up @@ -843,6 +843,7 @@ Status EfficientAttention(
: nullptr;
p.stream = stream;
p.has_custom_right_padding = true;
p.use_smooth_softmax = parameters.use_smooth_softmax;
run_memory_efficient_attention(p);

DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/python/transformers/test_flash_attn_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)

@parameterized.expand(gqa_no_past_flash_attention_test_cases())
Expand Down Expand Up @@ -2263,7 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,
Expand Down

0 comments on commit 0dc380c

Please sign in to comment.