Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

b3631 #309

Merged
merged 11 commits into from
Aug 27, 2024
Prev Previous commit
Next Next commit
metal : separate scale and mask from QKT in FA kernel (ggml-org#9189)
* metal : separate scale and mask from QKT in FA kernel

* metal : ne01 check no longer necessary

* metal : keep data in local memory
  • Loading branch information
ggerganov authored Aug 26, 2024
commit 06658ad7c37f440502de2b9486ce43c47b4ec710
35 changes: 13 additions & 22 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
}

simdgroup_store(mqk, ss + 8*cc, TF, 0, false);

const short tx = tiisg%4;
const short ty = tiisg/4;

// mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale;

if (logit_softcap != 0.0f) {
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
}

if (mask != q) {
// mqk = mqk + mask*slope
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
}
}
}

Expand All @@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
float ms[Q];

for (short j = 0; j < Q; ++j) {
const short p = tiisg;

const float m = M[j];
const float s = ss[j*TF + p];

// scale and apply the logitcap / mask
float s = ss[j*TF + tiisg]*scale;

if (logit_softcap != 0.0f) {
s = logit_softcap*precise::tanh(s);
}

if (mask != q) {
// mqk = mqk + mask*slope
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
}

smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
Expand All @@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
S[j] = S[j]*ms[j] + simd_sum(vs);

// the P matrix from the paper (Q rows, C columns)
ss[j*TF + p] = vs;
ss[j*TF + tiisg] = vs;
}

// create a QxQ diagonal matrix for rescaling the output
Expand Down