Skip to content

Commit

Permalink
metal : keep data in local memory
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Aug 26, 2024
1 parent e865686 commit ff23e8e
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2143,21 +2143,6 @@ kernel void kernel_flash_attn_ext_f16(
}
}

// scale and apply the mask (assume C = 32)
for (short j = 0; j < Q; ++j) {
// mqk = mqk*scale
ss[j*TF + tiisg] *= scale;

if (logit_softcap != 0.0f) {
ss[j*TF + tiisg] = logit_softcap*precise::tanh(ss[j*TF + tiisg]);
}

if (mask != q) {
// mqk = mqk + mask*slope
ss[j*TF + tiisg] += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
}
}

// used to detect blocks full of -INF
float smax = -INFINITY;

Expand All @@ -2167,7 +2152,18 @@ kernel void kernel_flash_attn_ext_f16(

for (short j = 0; j < Q; ++j) {
const float m = M[j];
const float s = ss[j*TF + tiisg];

// scale and apply the logitcap / mask
float s = ss[j*TF + tiisg]*scale;

if (logit_softcap != 0.0f) {
s = logit_softcap*precise::tanh(s);
}

if (mask != q) {
// mqk = mqk + mask*slope
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
}

smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
Expand Down

0 comments on commit ff23e8e

Please sign in to comment.