Skip to content

Commit

Permalink
Zen4 Flash Attention (#32)
Browse files Browse the repository at this point in the history
* Zen4 flash attention: moving useful parts from the kq_fused_softmax branch

* Add flash attention with soft-cap and fix D = 256 case

* Flash attention refinements

* Update FlashAttn comment

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
  • Loading branch information
ikawrakow and Kawrakow authored Sep 1, 2024
1 parent dbb1db9 commit dc023bc
Show file tree
Hide file tree
Showing 3 changed files with 634 additions and 0 deletions.
32 changes: 32 additions & 0 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -16149,6 +16149,38 @@ static void ggml_compute_forward_flash_attn_ext_f16(
scale /= softcap;
}

#if GGML_USE_IQK_MULMAT
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 &&
mask && mask->type == GGML_TYPE_F16) {
int64_t work_per_slice = D*nek1*neq1;
int ntg = 1;
if (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
if ((neq2*neq3)%(nth/ntg) == 0) {
//if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1);
int counter = 0;
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1/ntg;
if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
(const void *)((const char *)mask->data + iq1*mask->nb[1]),
scale, softcap,
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
}
}
}
return;
}
IQK_Flash_Attn_NotAvailable:;
}

#endif

const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

Expand Down
Loading

0 comments on commit dc023bc

Please sign in to comment.