diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 84e92e4fce574..34194abc10667 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2143,21 +2143,6 @@ kernel void kernel_flash_attn_ext_f16( } } - // scale and apply the mask (assume C = 32) - for (short j = 0; j < Q; ++j) { - // mqk = mqk*scale - ss[j*TF + tiisg] *= scale; - - if (logit_softcap != 0.0f) { - ss[j*TF + tiisg] = logit_softcap*precise::tanh(ss[j*TF + tiisg]); - } - - if (mask != q) { - // mqk = mqk + mask*slope - ss[j*TF + tiisg] += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; - } - } - // used to detect blocks full of -INF float smax = -INFINITY; @@ -2167,7 +2152,18 @@ kernel void kernel_flash_attn_ext_f16( for (short j = 0; j < Q; ++j) { const float m = M[j]; - const float s = ss[j*TF + tiisg]; + + // scale and apply the logitcap / mask + float s = ss[j*TF + tiisg]*scale; + + if (logit_softcap != 0.0f) { + s = logit_softcap*precise::tanh(s); + } + + if (mask != q) { + // mqk = mqk + mask*slope + s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; + } smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s));