From 5ca83a9c7138d985779f41ff85eaad40760fa24f Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 22 Jul 2024 23:42:06 -0700
Subject: [PATCH] Clean up softcapping bwd a bit

---
 README.md                              |  2 +-
 csrc/flash_attn/src/flash_bwd_kernel.h | 21 +++++----------------
 .../src/flash_bwd_launch_template.h    |  2 +-
 3 files changed, 7 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index b6efd8ee3..1d0897ab5 100644
--- a/README.md
+++ b/README.md
@@ -353,7 +353,7 @@ Thanks to @beginlner for this contribution.
 ### 2.6: Softcapping.
 
 Support attention with softcapping, as used in Gemma-2 and Grok models.
-Thanks to @Narsil for this contribution.
+Thanks to @Narsil and @lucidrains for this contribution.
 
 ## Performance
 
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 00cbc081e..4f95bd34a 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -480,16 +480,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
 
         // Softcapping - calculating dTanh and scaling dS later with it
-        auto dtanh = ([&]{
-            if constexpr (Is_softcap) {
-                Tensor _dtanh = make_tensor_like(scores);
-                flash::calculate_dtanh(scores, _dtanh, params.softcap);
-                return _dtanh;
-            }
-            else {
-                return nullptr;
-            }
-        }());
+        Tensor dtanh = make_tensor_like(scores);
+        if constexpr (Is_softcap) {
+            flash::calculate_dtanh(scores, dtanh, params.softcap);
+        }
 
         // Alibi
         if (Has_alibi) {
@@ -591,13 +585,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int mi = 0; mi < size<0>(dS); ++mi) {
             #pragma unroll
             for (int ni = 0; ni < size<1>(dS); ++ni) {
-
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
-
-                if constexpr (Is_softcap) {
-                    scaled_ds *= dtanh(mi, ni);
-                }
-
+                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
             }
         }
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index 2e141c8fe..5569466ad 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -99,7 +99,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
-                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
+                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
                         // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
                         if (smem_size_dq_dk_dv >= 48 * 1024) {
                             C10_CUDA_CHECK(cudaFuncSetAttribute(