Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metal : gemma2 flash attention support #9159

Merged
merged 2 commits into from
Aug 26, 2024
Merged

Conversation

slaren
Copy link
Collaborator

@slaren slaren commented Aug 24, 2024

Performance looks unchanged despite the new parameter, but I have only tested this with test-backend-ops.

@github-actions github-actions bot added the testing Everything test related label Aug 24, 2024
@slaren slaren marked this pull request as draft August 25, 2024 20:16
@ggerganov
Copy link
Owner

The following patch seems to fix the issue from #8542 on my Mac:

diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index ab2de69c..aba0b9a0 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2149,8 +2149,8 @@ kernel void kernel_flash_attn_ext_f16(
                     ss[8*cc + ty*TF + 2*tx + 1] *= scale;
 
                     if (logit_softcap != 0.0f) {
-                        ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*tanh(ss[8*cc + ty*TF + 2*tx + 0]);
-                        ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*tanh(ss[8*cc + ty*TF + 2*tx + 1]);
+                        ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
+                        ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
                     }
 
                     if (mask != q) {
@@ -2490,7 +2490,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
                         mqk *= scale;
 
                         if (logit_softcap != 0.0f) {
-                            mqk = logit_softcap*tanh(mqk);
+                            mqk = logit_softcap*precise::tanh(mqk);
                         }
 
                         mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;

@slaren
Copy link
Collaborator Author

slaren commented Aug 26, 2024

Yep, also fixes it for me.

@slaren slaren marked this pull request as ready for review August 26, 2024 08:51
Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't observe performance changes as well, should be good to merge

@slaren slaren merged commit 0c41e03 into master Aug 26, 2024
49 of 52 checks passed
@slaren slaren deleted the sl/metal-logit-softcap branch August 26, 2024 09:09
@ggerganov ggerganov mentioned this pull request Aug 26, 2024
4 tasks
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants