CPU/CUDA: Gemma 2 FlashAttention support #8542
Conversation
Currently, the Metal kernels always use F32 accumulators regardless of the selected precision, so this should not be a problem. You can fix the failing Metal tests like this:
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index b5939efa..d16f1c6e 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -798,6 +798,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
if (op->src[0]->ne[0] == 256) {
return false;
}
+ {
+ float logit_softcap;
+
+ memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
+
+ if (logit_softcap != 0.0f) {
+ return false;
+ }
+ }
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
We'll implement support in the future.
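The same kind of check generalizes to any backend's supports_op: read logit_softcap from the op parameters and decline the op while the kernels don't implement it. A small standalone helper in that spirit is sketched below; the op_params layout (index 0 = scale, index 1 = max_bias, index 2 = logit_softcap) is my reading of ggml_flash_attn_ext after this PR and should be treated as an assumption.

// sketch: helper for a backend's supports_op to detect logit softcapping
// (assumed op_params layout: 0 = scale, 1 = max_bias, 2 = logit_softcap)
#include <string.h>
#include "ggml.h"

static bool ggml_flash_attn_ext_uses_softcap(const struct ggml_tensor * op) {
    float logit_softcap;
    memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
    return logit_softcap != 0.0f;
}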
It seems Gemma 2 with FlashAttention and a quantized KV cache + partial offloading slows down prompt processing a lot. Is this expected behavior? FA without a quantized KV cache is fine though.
I didn't test that particular case; do you have similar issues with other models or only with Gemma 2?
Yes, Llama 3 8B is unaffected. RTX 2060, i7 9750H, Gemma 2 Q4_K_S, FA without quantized KV cache (-n 180 -c 4096 -t 6 --gpu-layers 25 --ignore-eos -fa)
FA with quantized KV cache (same as above + -ctk q8_0 -ctv q8_0)
Llama 3 8B Q4_K_S (17 GPU layers), FA without quantized KV cache
FA with quantized KV cache
Here's a comparison of these two models (partially offloaded). As you can see, with Gemma 2 the generation and especially the prompt processing speed slow down a lot as soon as you quantize the KV cache, which is not the case with Llama 3 8B (text generation speed even increases nicely).
I found some weird problems using -fa with this branch. Testing the imatrix Q2 and Q3 quants of the Tiger-Gemma 27B model completely broke the responses, with the model answering nonsense to any kind of prompt, while running without flash attention seems to work normally. I have not tested the 9B version. Interestingly, using the Alpaca template worked with flash attention, but the tokenization provided by Gemma would generate an infinite response of "Manneur Manneur Manneur Manneur Manneur..." I also confirmed the slow token initialization mentioned by Dampfinchen on lower quants.
There are potentially numerical issues that only appear with specific models. I'm assuming you downloaded the models off of Huggingface; can you link them? Also, do the non-imatrix models work correctly?
https://huggingface.co/mradermacher/Big-Tiger-Gemma-27B-v1-i1-GGUF I've tested the gemma-2-27b as well: https://huggingface.co/mradermacher/gemma-2-27b-it-i1-GGUF The issue I mentioned happens with all of them when using -fa.
apply logit_softcap to scale in kernel
disable logit softcapping tests on Metal
If nobody else is going to ask, I will! Has this PR been overlooked?
src/llama.cpp
Outdated
#ifdef GGML_USE_METAL
    if (params.flash_attn && model->hparams.attn_soft_cap) {
        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
        params.flash_attn = false;
    }
#endif // GGML_USE_METAL
We shouldn't add exceptions for specific backends like this. Eventually backends will be loaded at runtime, and these macros will be removed. It is also not completely correct, because it is possible to use -ngl 0 with Metal.
Is it possible with the current code to do this correctly in a simple way? With a debugger I can see that model->bufs is already set and (I think) lets you draw conclusions about the available backends. But I don't see a way in ggml-backend.h to convert a ggml_backend_buffer_t to a ggml_backend_t, which I could then feed to the functions that check which backends are in use.
Alternatively, if there were a Metal implementation, this issue would also be resolved. Has there been any progress on that front?
Just remove the check entirely and let's merge this as is; I will open a PR later with a Metal implementation. If the supports_op implementation is correct, it will fall back to CPU, but it will not crash.
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"

-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
Would it affect performance significantly to do this check at runtime? Otherwise this is going to double the number of template instances.
When evaluating 8192 tokens with a batch size of 1 on my 3090, a runtime check reduces the end-to-end performance by ~19%. For a batch size of 512 the difference is ~1%. But the latter is only because the kernels using tensor cores are poorly written: due to the use of the WMMA interface, data has to be written to and read from shared memory, which prevents the compiler from properly reordering instructions (so it doesn't matter much if a conditional statement for logit softcapping also does this). My plan is to rewrite the kernels as part of my work on the LLM backwards pass in the coming months; I expect the difference for large batches to then be similar to the difference for batch size 1.
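For illustration, here is a minimal sketch (not the actual llama.cpp kernels) of why the compile-time template parameter matters: with a bool template argument the softcap branch and the tanhf call are removed entirely from the instances that don't need them, whereas a runtime check keeps them in every kernel. The kernel and launcher names here are invented for this example.

#include <cuda_runtime.h>

template <bool use_logit_softcap>
__global__ void scale_logits(float * logits, int n, float scale, float logit_softcap) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    float v = logits[i]*scale;
    if (use_logit_softcap) {
        // compile-time constant: the branch and the tanhf vanish from the <false> instance
        v = logit_softcap*tanhf(v);
    }
    logits[i] = v;
}

// host-side dispatch: the runtime value of logit_softcap selects the template instance;
// as described in this PR, the scale passed to the kernel is pre-divided by logit_softcap
static void launch_scale_logits(float * logits, int n, float scale, float logit_softcap, cudaStream_t stream) {
    const int block_size = 256;
    const int nblocks    = (n + block_size - 1)/block_size;
    if (logit_softcap == 0.0f) {
        scale_logits<false><<<nblocks, block_size, 0, stream>>>(logits, n, scale, logit_softcap);
    } else {
        scale_logits<true><<<nblocks, block_size, 0, stream>>>(logits, n, scale/logit_softcap, logit_softcap);
    }
}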
Anyway, regarding how to move forward: the compilation time/binary size does not grow by a factor of 2 from this PR, because only the template instances for head sizes 128 and 256 are needed; the rest can be skipped. However, because larger head sizes contribute a disproportionate amount to the compilation time/binary size, the impact of FlashAttention would still grow by a factor of ~1.7. If that is judged to be too much, considering that right now there is effectively only a single model that would benefit, I would suggest making Gemma 2 support opt-in in a similar way to GGML_CUDA_FA_ALL_QUANTS.
We should then also think about how to handle opt-in features in general. It's basically guaranteed that code for training would add further to compilation time/binary size even though most users will not need it. But that case can be covered relatively easily (for llama.cpp) by building the training code only when building at least one example that uses it. Still, for ggml we should maybe add one define per optional CUDA feature and then also a single define like GGML_CUDA_FULL that enables all optional features for convenience.
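If opt-in were chosen, the gating could look roughly like the following sketch; GGML_CUDA_FA_SOFTCAP is an invented define used purely for illustration (the PR as-is always builds the softcap instances for head sizes 128 and 256).

template <int D, bool use_logit_softcap> // D == head size
void launch_fattn_case() {
    // ... select and launch the kernel specialized for this head size ...
}

// instances needed by all models (no logit softcapping)
template void launch_fattn_case< 64, false>();
template void launch_fattn_case<128, false>();
template void launch_fattn_case<256, false>();

#ifdef GGML_CUDA_FA_SOFTCAP
// extra instances only needed for models with logit softcapping (Gemma 2)
template void launch_fattn_case<128, true>();
template void launch_fattn_case<256, true>();
#endif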
I am not concerned about the binary size, but it would be nice to have some option to accelerate the compilation time. I think the simplest way to do that would be to build everything by default and add options to disable some of the functionality (make it opt-out, rather than having to enable every feature or add a GGML_CUDA_FULL). I don't think it is a big issue at the moment, and I don't think we need to make Gemma 2 support opt-in.
@slaren Can this be merged? I would very much like the performance improvements from FlashAttention support for the Gemma 2 model.
The problem here is not so much slaren but rather me. Basically, the only limiting factor for me is how much time I have, so I'm trying to optimize for throughput rather than latency when it comes to features. My impression was that working on other things first would result in less work overall, so I've put this PR on the back burner.
Force-pushed from c325464 to 6e40804
Like @vitorfdl, since updating after this merge, Gemma 2 27B refuses to work for me with -fa. Every response is:
I'm using Bartowski's gemma-2-27b-it-Q6_K imatrix quant. Sampling params:
Example prompt:
I am unable to reproduce the issue. What hardware are you using? Can you post the exact command with which the problem occurs?
WSL2, Ubuntu 22.04.4 LTS on Windows 11, NVIDIA GeForce 3090.
Result:
until I cancel. (Without --special there's no output.) Remove --flash-attn and it works as expected:
It seems to come down to whether NGL is used with --flash-attn on Gemma 2. Simplified settings with FA and without NGL:
vs. FA with NGL:
It seems it also works with partial offloading.
It also happens with Metal in #9159.
Thank you, I was able to reproduce the issue. Setting the precision to FP32 fixes it for me: #9166. I vaguely recall that the Metal FA implementation always uses FP32 precision anyway, though.
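For reference, a minimal sketch of what that fix amounts to on the graph-building side, assuming the ggml_flash_attn_ext signature after this PR and the existing ggml_flash_attn_ext_set_prec helper (the function below is illustrative, not actual llama.cpp code):

#include "ggml.h"

// build the FlashAttention op and force F32 precision when logit softcapping is used,
// since FP16 accumulation was observed to produce broken output with Gemma 2 (see #9166)
static struct ggml_tensor * build_fattn_softcap_safe(
        struct ggml_context * ctx,
        struct ggml_tensor * q, struct ggml_tensor * k, struct ggml_tensor * v,
        struct ggml_tensor * mask, float scale, float max_bias, float logit_softcap) {
    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);
    if (logit_softcap != 0.0f) {
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
    }
    return cur;
}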
* CPU/CUDA: Gemma 2 FlashAttention support
* apply logit_softcap to scale in kernel
* disable logit softcapping tests on Metal
* remove metal check
This PR adds FlashAttention support for Gemma 2 on the CPU and CUDA backends by adding another parameter that controls the logit softcap. A value of 0.0f indicates no logit softcapping. When a different value is specified, the scale parameter is divided by this value, and prior to the softmax the hyperbolic tangent as well as the scaling by logit_softcap is applied. Because this changes the position of the parameter that indicates the precision needed for FlashAttention, I think this PR breaks the Metal FlashAttention implementation (I am not aware of any other FlashAttention implementations). I tried searching for the spot where the Metal code retrieves the FlashAttention precision but was not able to find it; help would be very much appreciated.
This PR adds template instances with logit softcapping for the CUDA FlashAttention kernels for head sizes 128 (Gemma 2 27B) and 256 (Gemma 2 9B). For all other head sizes no instance is compiled, since (to my knowledge) there are no models that would use them. The FlashAttention kernels that do not use FP16 tensor cores only support head sizes 64 and 128, so Gemma 2 9B is only supported on NVIDIA GPUs with compute capability >= 7.0. For simplicity, tests/test-backend-ops.cpp only checks head size 128 (which should be enough since logit softcapping is a scalar operation and does not depend on head size).
GPU Performance