Skip to content

Commit

Permalink
metal : separate scale and mask from QKT in FA kernel (ggerganov#9189)
Browse files (browse the repository at this point in the history)
* metal : separate scale and mask from QKT in FA kernel

* metal : ne01 check no longer necessary

* metal : keep data in local memory
  • Loading branch information
ggerganov authored and arthw committed Nov 18, 2024
1 parent 9ff4df4 commit c99adef
Showing 1 changed file with 13 additions and 22 deletions.
35 changes: 13 additions & 22 deletions ggml/src/ggml-metal.metal
Original file line number | Diff line number | Diff line change
Expand Up @@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
}

simdgroup_store(mqk, ss + 8*cc, TF, 0, false);

const short tx = tiisg%4;
const short ty = tiisg/4;

// mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale;

if (logit_softcap != 0.0f) {
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
}

if (mask != q) {
// mqk = mqk + mask*slope
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
}
}
}

Expand All @@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
float ms[Q];

for (short j = 0; j < Q; ++j) {
const short p = tiisg;

const float m = M[j];
const float s = ss[j*TF + p];

// scale and apply the logitcap / mask
float s = ss[j*TF + tiisg]*scale;

if (logit_softcap != 0.0f) {
s = logit_softcap*precise::tanh(s);
}

if (mask != q) {
// mqk = mqk + mask*slope
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
}

smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
Expand All @@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
S[j] = S[j]*ms[j] + simd_sum(vs);

// the P matrix from the paper (Q rows, C columns)
ss[j*TF + p] = vs;
ss[j*TF + tiisg] = vs;
}

// create a QxQ diagonal matrix for rescaling the output
Expand Down

0 comments on commit c99adef

Please sign in to comment.