ggml: Mamba CUDA kernel performance improvement #9186
Conversation
…end + test case for each op
…(llama : use im2col and mul_mat to perform convolution for Mamba); GPU version breaks with assert because of unsupported MUL_MAT
Thanks for the quick work! It sounds like a great improvement in terms of performance! But there seems to be some new bug introduced by the changes. The output is no longer the same as in my original version. More specifically, the new one seems to ignore the prompt, e.g. Original (from my Mamba tinystories test case):
New:
(I also checked a few other examples, also with Falcon Mamba 7B Instruct - it consistently outputs text unrelated to the prompt.)
Generally speaking the outputs will always change if the code is changed due to differences in floating point rounding error. And due to the numerical instability of neural networks there is no guaranteed upper bound for how much these rounding errors will blow up for end-to-end evaluation. I think the best way to check for these things is the
Good point, but here the change in output is so striking that it is visible with the naked eye...
ggml/src/ggml-cuda/ssm_conv.cu
Outdated
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr) {
Add __restrict__ to the pointers, see #2140.
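For illustration, a restrict-qualified version of the signature above could look like the sketch below (the kernel name ssm_conv_f32 is assumed here, and the parameter list is copied from the snippet; this is not the exact upstream code):

static __global__ void ssm_conv_f32(
        const float * __restrict__ src0, const float * __restrict__ src1,
        const int src0_nb0, const int src0_nb1, const int src0_nb2,
        const int src1_nb1,
        float * __restrict__ dst,
        const int dst_nb0, const int dst_nb1, const int dst_nb2,
        const int nc, const int ncs, const int nr) {
    // __restrict__ promises the compiler that these pointers never alias,
    // which allows it to reorder and cache loads/stores more aggressively.
}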
ggml/src/ggml-cuda/ssm_conv.cu
Outdated
const int tid = threadIdx.x;
const int i2 = blockIdx.x;
const int i3 = threadIdx.y;
Be aware that this is not a noop. There are special, shared registers for e.g. threadIdx.x, and you are taking that data and moving it to regular registers. IIRC the regular registers are slightly faster to access, but this will also increase register pressure.
I see. I believe that on most GPUs the registers are more than sufficient to meet the requirements.
Sorry, I think I worded my previous post poorly. What I meant is that the regular registers are slightly faster (but there is only a limited amount of them).
ggml/src/ggml-cuda/ssm_conv.cu
Outdated
#pragma unroll
for (int i1 = 0; i1 < ir; ++i1) {
ir is not known at compile time so this loop cannot actually be unrolled, same for the other loop.
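As a generic illustration of that point (not the PR's code): #pragma unroll can only fully unroll a loop whose trip count the compiler knows, e.g. a literal constant or a template parameter.

// Generic illustration only: the loop bound is a template parameter, so the
// compiler knows the trip count and can actually unroll the loop.
template <int nc>
static __global__ void unroll_example(const float * __restrict__ x, float * __restrict__ y) {
    float sum = 0.0f;
#pragma unroll
    for (int i0 = 0; i0 < nc; ++i0) {   // nc is known at compile time here
        sum += x[threadIdx.x*nc + i0];
    }
    y[threadIdx.x] = sum;
}
// launched e.g. as unroll_example<4><<<grid, block>>>(x, y), making the bound a compile-time constant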
Ok
float sumf = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
Unless I'm missing something the memory access pattern here is bad with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.
Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.
In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.
To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming the WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.
I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.
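A rough sketch of that idea, simplified to a single row and with assumed names (NC, n_t, the kernel itself); the multi-row and multi-sequence dimensions are omitted, so this is not the PR's kernel:

#define NC 4           // conv width; always 4 for Mamba per the comment above
#define WARP_SIZE 32

static __global__ void conv_coalesced_sketch(
        const float * __restrict__ s,   // source row of the self-overlapping view, length n_t + NC - 1
        const float * __restrict__ c,   // NC convolution weights
        float * __restrict__ dst,       // n_t outputs
        const int n_t) {
    const int lane = threadIdx.x % WARP_SIZE;
    const int i0   = lane % NC;                                   // tap within the window
    const int i2   = (blockIdx.x*blockDim.x + threadIdx.x) / NC;  // output column; WARP_SIZE/NC columns per warp

    // lanes of one warp read addresses within a small contiguous window, so the loads coalesce;
    // lanes past the last column contribute zero but still take part in the shuffle below
    float partial = (i2 < n_t) ? s[i2 + i0] * c[i0] : 0.0f;

    // combine the NC partial products of each column within the warp
#pragma unroll
    for (int offset = NC/2; offset > 0; offset /= 2) {
        partial += __shfl_down_sync(0xffffffff, partial, offset, NC);
    }
    if (i0 == 0 && i2 < n_t) {
        dst[i2] = partial;
    }
}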
Unless I'm missing something the memory access pattern here is bad with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.
Thanks. The current memory access pattern is more suitable for CPUs. I'm thinking about ways to address this issue.
Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column. [...] To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2 [...]
Good idea, I am currently testing according to your method.
To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2 [...]
I've found a simple implementation for ssm_conv that coalesces memory accesses, giving about a 2x performance improvement, and I've already pushed it to this PR! For ssm_scan, I'm feeling at a loss for optimization ideas.
I'm feeling at a loss for optimization ideas.
Use one warp per iteration of nc with each thread calculating a partial sum, then combine the partial sums via warp_reduce_sum and have the first thread in the warp write back the result.
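A minimal sketch of that pattern (placeholder kernel and tensor names; warp_reduce_sum and WARP_SIZE are the existing helpers from ggml-cuda's common.cuh, assumed available via the include):

#include "common.cuh"   // provides warp_reduce_sum() and WARP_SIZE in ggml-cuda

// Sketch only (placeholder names, not the PR's kernel): one warp per output
// row; each lane accumulates a strided partial dot product, the partials are
// combined with warp shuffles, and lane 0 writes the result.
static __global__ void dot_per_row_sketch(
        const float * __restrict__ a, const float * __restrict__ b,
        float * __restrict__ dst, const int n_reduce) {
    const int row  = blockIdx.x;    // one warp (= one block here) per row
    const int lane = threadIdx.x;   // assumes blockDim.x == WARP_SIZE

    float partial = 0.0f;
    for (int i0 = lane; i0 < n_reduce; i0 += WARP_SIZE) {
        partial += a[row*n_reduce + i0] * b[row*n_reduce + i0];
    }
    partial = warp_reduce_sum(partial);   // shuffle-based reduction within the warp

    if (lane == 0) {
        dst[row] = partial;
    }
}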
Ha ha, I also once thought about using the warp-level API to calculate the sum. However, after taking a closer look, I realized that these additions accumulate into a single thread's registers rather than summing across threads in a block, so warp_reduce_sum might not be applicable here. Thanks for your review. If you have any other suggestions or better ideas, please feel free to share them; your input is greatly appreciated.
ggml/src/ggml-cuda/ssm_scan.cu
Outdated
My comments for ssm_conv.cu largely apply here as well.
Note that
The test passes on fae826f
Hi, I fixed the numerical error in the new commit and made some modifications according to @JohannesGaessler's suggestions, but there is still a lot of room for performance improvement (e.g. memory coalescing). The scan op performance is now 0.53 GB/s, only slightly better than @jploski's version. I will attempt to optimize these ops on the current basis, but it may take a long time.
I should mention that I know nothing about Mamba in terms of architecture or correctness; I was only commenting on the CUDA code in terms of optimization, since I'm probably the contributor with the most experience in that area.
float sumf = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
I'm feeling at a loss for optimization ideas.
Use one warp per iteration of nc with each thread calculating a partial sum, then combine the partial sums via warp_reduce_sum and have the first thread in the warp write back the result.
Thank you for your valuable input, especially for your expert guidance on CUDA code optimization.
@jploski @ggerganov @compilade Hi guys, do you have any comments on this PR? I can't see any further areas for optimization at the moment, and the unit tests are passing. Haha, I've discovered that my implementation is very similar to the Meta version. The current version might be the fastest way to run this on NVIDIA GPUs.
Using falcon-mamba-7b-instruct-F16 I see an 18% performance improvement vs. my version (~12.3 tokens/sec vs. ~14.5 tokens/sec). The output of each version is different (and also, in tests of the tinystories model, different from CPU after ~20 or so tokens), but coherent (I haven't checked perplexity of both vs. Python-generated text yet).

What I'm wondering about is the following quote from the original Mamba paper: "Concretely, instead of preparing the scan input (𝑨, 𝑩) of size (B, L, D, N) in GPU HBM (high-bandwidth memory), we load [...]"

Does the current implementation already take advantage of it? As far as I understand, the distinction between SRAM vs. HBM refers to utilizing shared memory in CUDA kernels. What I do not understand is whether the "fast" path in Python does anything explicit to manage this memory, or if it is already a side effect of the base algorithm.

One thing which I also haven't done yet is to measure the speed of generation of the "fast" Python implementation and compare it to this PR. I recall benchmarks showing that Mamba's generation is slower than a transformer's for small context lengths, and only surpasses it after some threshold (because it's constant time per token, unlike in a transformer). So as a smoke test, it may also be good to confirm that this holds for our implementation.
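In CUDA terms, "SRAM" roughly corresponds to shared memory and registers. A toy sketch of that idea (illustration of the concept only, with assumed names; not the paper's kernel and not this PR's code):

// Toy illustration only: the recurrent state stays in shared memory (on-chip
// SRAM) for the whole sequence; global memory (HBM) is touched only for the
// per-token inputs and outputs, not for intermediate state.
static __global__ void scan_sram_sketch(
        const float * __restrict__ x,   // [n_tokens, d_state] inputs in HBM
        float * __restrict__ y,         // [n_tokens, d_state] outputs in HBM
        const int d_state, const int n_tokens) {
    extern __shared__ float state[];    // per-block state kept on-chip
    const int tid = threadIdx.x;        // one thread per state dimension

    if (tid < d_state) {
        state[tid] = 0.0f;
    }
    __syncthreads();

    for (int t = 0; t < n_tokens; ++t) {
        if (tid < d_state) {
            // placeholder recurrence; the real SSM update uses per-step dA, dB, C
            state[tid] = 0.9f*state[tid] + x[t*d_state + tid];
            y[t*d_state + tid] = state[tid];
        }
        __syncthreads();
    }
}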
This point has not been utilized. The implementation in the paper employs a tiling algorithm similar to flash-attention, which is quite complex overall. I think it will be a substantial project and should be approached step by step.
If it's the implementation from the paper, it should be faster because its operators are specifically optimized for CUDA. However, it has certain requirements for the CUDA version. As for minimamba, personally I feel it should be faster than that. I will test.
Fun. Is the output of your version the same as that of the CPU (tinystories model)?
No, it also diverges after some initial tokens.
Correctness test:
./build/bin/llama-batched -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -p "Hello, my name is" -np 16 -n 32
Result:
GPU version perplexity test
CPU version perplexity test
I have done a perplexity test; the PPL value appears to be the same as the CPU version. Therefore, I believe these ops are correct.
I found it a bit sus that your PPL values matched exactly, so I ran my own test (./llama-perplexity -m /mnt/f2fs/mamba/tinystories_f32.gguf -f scripts/wikitext-2-raw/wiki.test.raw -ngl ...).

The results:
CPU: Final estimate: PPL = 957.3512 +/- 9.12936

So yes, I think this is ready to be unleashed on the unsuspecting world, despite what our very friendly kindordial personordials might say...
For which batch sizes are these new kernels used?
Unless the vocabulary size for the model is absolutely gargantuan those values look to me like a bug. A value < 10 should be expected for example for LLaMA 3. I never tested TinyStories but I can't imagine that it's that bad at predicting Wikitext tokens.
That value also looks suspiciously large to me.
The tinystories model I'm using here is one I trained myself off mamba-130m-hf, which is a small model, and the TinyStories dataset is intentionally very restricted in content type. The outputs my model produces are pretty bad overall (possibly not just because it's a "silly" dataset, but also because the model is undertrained on it). So I wouldn't pay much attention to the absolute value of PPL for wiki. The tokenizer vocabulary here is the same as mamba-130m-hf. In that test I was mostly interested in whether the CPU perplexity comes out exactly the same as the GPU's, as in @piDack's experiment.
By LLaMA 3 are you referring to a big model? The Mamba models we both used for computing PPL here are just 130M parameters, so the PPL=22 estimate of @piDack may be all right after all.
Once again, thanks for your code review. I tested the perplexity with different batch sizes on the GPU:

# batch = 3
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 3
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 5
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 19
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 31
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 32
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99

CPU: recompiled to use OpenMP as the backend, then:
make clean
make -j llama-perplexity
./llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ub 31

Final estimate: PPL = 22.5590 +/- 0.17848
Metaform #8546: PPL = 25.0894 +/- 0.1855
Hello, I am wondering how to test the speed. What tools do you use to get the value of 2.86 GB/s?
I can't comment on that particular metric. Generally you can profile individual kernels with NVIDIA Nsight Compute (the command ncu or ncu-ui). You can simply let this tool run the usual llama.cpp command line to generate tokens while specifying upfront the kernel you want to collect performance data about (the "Filter" tab in ncu-ui). However, note that for optimizations to make practical sense you should also keep an eye on how much that kernel's execution contributes to the overall execution time (in particular the per-token memory transfers between GPU and host might dominate the execution time).
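For example, a command along these lines (the kernel-name regex, model path, and llama.cpp arguments are placeholders; adjust to your build):

# profile up to 10 launches of kernels whose name matches "ssm_conv" and write a report
ncu --kernel-name regex:ssm_conv --launch-count 10 --set full -o ssm_conv_report \
    ./build/bin/llama-cli -m /path/to/model.gguf -p "Hello" -n 32 -ngl 99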
Got it, thank you, I will give it a try.
I think this PR has been superseded by PR #10558.
For #6758. I have optimized the performance by 10x based on @jploski's PR https://github.com/jploski/llama.cpp/tree/falcon_mamba_cuda. Under the current test case, the performance of conv has improved from 0.62 GB/s to 6.47 GB/s, and the performance of scan has increased from 0.30 GB/s to 2.86 GB/s.