
ggml:Mamba Cuda kernel performance improve #9186

Closed
piDack wants to merge 20 commits

Conversation

piDack
Contributor

@piDack piDack commented Aug 26, 2024

For #6758. I have optimized the performance by 10x based on @jploski's PR https://github.com/jploski/llama.cpp/tree/falcon_mamba_cuda. Under the current test case, the performance of conv has improved from 0.62 GB/s to 6.47 GB/s, and the performance of scan has increased from 0.30 GB/s to 2.86 GB/s.
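
For reference, per-op throughput figures like these can be collected with the perf mode of test-backend-ops; a possible invocation (not necessarily the exact one used here) is:

GGML_CUDA=1 make -j tests
./tests/test-backend-ops perf -o SSM_CONV
./tests/test-backend-ops perf -o SSM_SCAN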

@piDack piDack changed the title Mamba Cuda kernel performance improve ggml:Mamba Cuda kernel performance improve Aug 26, 2024
@jploski
Contributor

jploski commented Aug 26, 2024

Thanks for the quick work! It sounds like a great improvement in terms of performance! But there seems to be a new bug introduced by the changes. The output is no longer the same as in my original version. More specifically, the new one seems to ignore the prompt, e.g.

Original (from my Mamba tinystories test case):

Sara and Ben are playing in the snow. They make a big snowman with a hat, a scarf and a carrot nose. They put on their warm clothes and go inside to have some hot chocolate.
"Look at my snowman!" Sara says. "He is so pretty! He can talk and sing."
Ben looks up and sees his snowman's face. He smiles and nods.

New:

Sara and Ben are playing in the snow.
Molly was a three year old girl who loved to play with her friends. One day, she decided to go for a walk in the park. She saw something very interesting - a big tree! Molly wanted to climb it so badly that she could see all of its branches.

(I also checked a few other examples, also with Falcon Mamba 7B Instruct - it consistently outputs text unrelated to the prompt.)

@JohannesGaessler
Collaborator

Thanks for the quick work! It sounds like a great improvement in terms of performance! But there seems to be a new bug introduced by the changes. The output is no longer the same as in my original version.

Generally speaking the outputs will always change if the code is changed due to differences in floating point rounding error. And due to the numerical instability of neural networks there is no guaranteed upper bound for how much these rounding errors will blow up for end-to-end evaluation. I think the best way to check for these things is the llama-perplexity binary. Check whether there is a significant change for various batch sizes.

@jploski
Contributor

jploski commented Aug 26, 2024

Thanks for the quick work! It sounds like a great improvement in terms of performance! But there seems to be a new bug introduced by the changes. The output is no longer the same as in my original version.

Generally speaking the outputs will always change if the code is changed due to differences in floating point rounding error. And due to the numerical instability of neural networks there is no guaranteed upper bound for how much these rounding errors will blow up for end-to-end evaluation. I think the best way to check for these things is the llama-perplexity binary. Check whether there is a significant change for various batch sizes.

Good point, but here the change in output is so striking that it is visible with the naked eye...

Comment on lines 5 to 10
const float * src0, const float * src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr) {
Collaborator

Add __restrict__ to the pointers, see #2140 .
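
For example (sketch only, mirroring the parameter list in the hunk above):

const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1,
float * __restrict__ dst,
const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr) {

This tells the compiler the pointers do not alias, which lets it keep loads in registers more aggressively.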

Comment on lines 13 to 15
const int tid = threadIdx.x;
const int i2 = blockIdx.x;
const int i3 = threadIdx.y;
Collaborator

@JohannesGaessler JohannesGaessler Aug 26, 2024

Be aware that this is not a no-op. There are special, shared registers for e.g. threadIdx.x and you are taking that data and moving it to regular registers. IIRC the regular registers are slightly faster to access, but it will also increase register pressure.

Contributor Author

I see. I believe that on most GPUs the registers are more than sufficient to meet the requirements.

Collaborator

Sorry, I think I worded my previous post poorly. What I meant is that the regular registers are slightly faster (but there is only a limited amount of them).

Comment on lines 35 to 36
#pragma unroll
for (int i1 = 0; i1 < ir; ++i1) {
Collaborator

ir is not known at compile time so this loop cannot actually be unrolled, same for the other loop.
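
For the pragma to take effect the trip count has to be a compile-time constant, e.g. via a template parameter; a hypothetical sketch (not code from this PR):

// The bound nc is known at compile time here, so #pragma unroll can actually unroll.
template <int nc>   // e.g. nc == 4 for Mamba's conv window
static __device__ __forceinline__ float dot_window(const float * __restrict__ s,
                                                   const float * __restrict__ c) {
    float sumf = 0.0f;
#pragma unroll
    for (int i0 = 0; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }
    return sumf;
}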

Contributor Author

Ok

float sumf = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
Collaborator

Unless I'm missing something the memory access pattern here is bad with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.
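
For illustration (a generic sketch, not code from this PR), the difference between a strided and a coalesced access pattern:

// In copy_strided consecutive threads read elements that are `stride` floats apart,
// so each lane touches a different memory segment; in copy_coalesced consecutive
// threads read consecutive elements, so a warp's 32 loads combine into a few wide
// memory transactions.
__global__ void copy_strided(const float * __restrict__ src, float * __restrict__ dst,
                             const int n, const int stride) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i*stride];
    }
}

__global__ void copy_coalesced(const float * __restrict__ src, float * __restrict__ dst,
                               const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}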

Collaborator

@compilade compilade Aug 26, 2024

Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.

In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.

To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming the WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.

I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.
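
Roughly, the lane-to-work mapping I have in mind would be something like this (hypothetical sketch, not code from this PR; here s is the per-sequence base of the view, ncs its row stride, and n_t the number of steps):

const int lane = threadIdx.x % WARP_SIZE;     // WARP_SIZE (32) is a multiple of nc == 4
const int i0   = lane % nc;                   // position inside the conv window
const int step = lane / nc;                   // which of the WARP_SIZE/nc steps handled by this warp
const int i2   = blockIdx.x*(WARP_SIZE/nc) + step;
float v = 0.0f;
if (i2 < n_t) {                               // needed when n_t is not a multiple of WARP_SIZE/nc
    v = s[i1*ncs + i2 + i0] * c[i1*nc + i0];  // advancing i2 shifts the view by one column
}
// the nc partial products of each step still need to be summed, e.g. with
// __shfl_down_sync over each group of nc lanes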

Contributor Author

@piDack piDack Aug 27, 2024

Unless I'm missing something the memory access pattern here is bad with each thread accessing completely different data. You will achieve orders of magnitude higher memory bandwidth by accessing the data in a coalesced manner.

Thanks. The current memory access pattern is more suitable for CPUs. I'm thinking about ways to address this issue.

Contributor Author

Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.

In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.

To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming the WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.

I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.

Good idea, I am currently testing according to your method.

Contributor Author

@piDack piDack Aug 27, 2024

Ideally, this could use the fact that this is operating on a self-overlapping view, so advancing with i2 shifts the view by one column.

In practice, for Mamba (and Mamba-2) models, nc is always 4, which might help with unrolling.

To coalesce memory accesses (at least for large prompts), I guess each warp could operate on WARP_SIZE/nc steps at a time over i2, assuming the WARP_SIZE is a multiple of 4 (is that always the case?), but this might need special handling of cases where i2 is not evenly divided by that.

I don't have much experience with CUDA (yet), so this might be misleading, but hopefully still helps.

I've found a simple implementation for ssm_conv that coalesces memory accesses and gives a 2x performance improvement, and I've already submitted it! For ssm_scan, I'm feeling at a loss for optimization ideas.

Collaborator

I'm feeling at a loss for optimization ideas.

Use one warp per iteration of nc with each thread calculating a partial sum, then combine the partial sums via warp_reduce_sum and have the first thread in the warp write back the result.
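
A minimal sketch of that pattern (illustrative only; ggml's CUDA headers provide an equivalent warp_reduce_sum helper):

// One 32-thread block (a single warp) per output row: each lane accumulates a
// strided partial sum, the warp combines the partials with shuffles, and lane 0
// writes the result back.
static __device__ __forceinline__ float warp_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}

__global__ void row_dot(const float * __restrict__ s, const float * __restrict__ c,
                        float * __restrict__ dst, const int nc) {
    const int i1 = blockIdx.x;
    float sumf = 0.0f;
    for (int i0 = threadIdx.x; i0 < nc; i0 += 32) {
        sumf += s[i1*nc + i0] * c[i1*nc + i0];
    }
    sumf = warp_sum(sumf);
    if (threadIdx.x == 0) {
        dst[i1] = sumf;
    }
}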

Contributor Author

@piDack piDack Aug 29, 2024

Ha ha, I also thought about using the warp-level API to calculate the sum. However, after taking a closer look, I realized that these additions accumulate into a single thread's registers, not across the threads of a block. Therefore, warp_reduce_sum might not be applicable here. Thanks for your review. If you have any other suggestions or better ideas, please feel free to share them. Your input is greatly appreciated.

Collaborator

My comments for ssm_conv.cu largely apply here as well.

@github-actions github-actions bot added the testing (Everything test related) and Nvidia GPU (Issues specific to Nvidia GPUs) labels Aug 26, 2024
@ggerganov
Owner

Thanks for the quick work! It sounds like a great improvement in terms of performance! But there seems to be a new bug introduced by the changes. The output is no longer the same as in my original version.

Generally speaking the outputs will always change if the code is changed due to differences in floating point rounding error. And due to the numerical instability of neural networks there is no guaranteed upper bound for how much these rounding errors will blow up for end-to-end evaluation. I think the best way to check for these things is the llama-perplexity binary. Check whether there is a significant change for various batch sizes.

Good point, but here the change in output is so striking that it is visible with the naked eye...

Note that GGML_CUDA=1 make -j tests && ./tests/test-backend-ops -o SSM_SCAN is failing using the latest commit on this branch, so definitely a bug has been introduced at some point:

Backend 2/2 (CUDA0)
  Backend name: CUDA0
  SSM_SCAN(type=f32,d_state=16,d_inner=1536,n_seq_tokens=7,n_seqs=2): [SSM_SCAN] NMSE = 161.401030491 > 0.000000100 FAIL
  1415/1416 tests passed
  Backend CUDA0: FAIL

The test passes on fae826f

@piDack
Contributor Author

piDack commented Aug 27, 2024

Thanks for the quick work! It sounds like a great improvement in terms of performance! But there seems to be a new bug introduced by the changes. The output is no longer the same as in my original version.

Generally speaking the outputs will always change if the code is changed due to differences in floating point rounding error. And due to the numerical instability of neural networks there is no guaranteed upper bound for how much these rounding errors will blow up for end-to-end evaluation. I think the best way to check for these things is the llama-perplexity binary. Check whether there is a significant change for various batch sizes.

Good point, but here the change in output is so striking that it is visible with the naked eye...

Note that GGML_CUDA=1 make -j tests && ./tests/test-backend-ops -o SSM_SCAN is failing using the latest commit on this branch, so definitely a bug has been introduced at some point:

Backend 2/2 (CUDA0)
  Backend name: CUDA0
  SSM_SCAN(type=f32,d_state=16,d_inner=1536,n_seq_tokens=7,n_seqs=2): [SSM_SCAN] NMSE = 161.401030491 > 0.000000100 FAIL
  1415/1416 tests passed
  Backend CUDA0: FAIL

The test passes on fae826f

Hi, I fixed the numerical error in the new commit and made some modifications according to @JohannesGaessler's suggestions, but there is still a lot of room for performance improvement (e.g. memory coalescing). The SCAN op performance is now 0.53 GB/s, only slightly better than @jploski's version. I will attempt to optimize these ops on the current basis, but it may take a long time.

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
Testing 2 backends

Backend 1/2 (CPU)
  Skipping CPU backend
Backend 2/2 (CUDA0)
  Backend name: CUDA0
  SSM_SCAN(type=f32,d_state=16,d_inner=1536,n_seq_tokens=7,n_seqs=2): OK
  1416/1416 tests passed
  Backend CUDA0: OK

2/2 backends passed
OK

@piDack piDack requested a review from compilade August 27, 2024 12:46
Collaborator

@JohannesGaessler JohannesGaessler left a comment

I should mention that I know nothing about Mamba in terms of architecture or correctness, I was only commenting on the CUDA code in terms of optimization since I'm probably the contributor with the single most experience in that area.


@piDack
Contributor Author

piDack commented Aug 29, 2024

I should mention that I know nothing about Mamba in terms of architecture or correctness, I was only commenting on the CUDA code in terms of optimization since I'm probably the contributor with the single most experience in that area.

Thank you for your valuable input, especially your expert guidance on CUDA code optimization.

@piDack
Contributor Author

piDack commented Aug 29, 2024

@jploski @ggerganov @compilade Hi guys, do you have any comments on this PR? I can't see any areas for optimization at the moment, and the unit tests are passing. Haha, I've discovered that my implementation is very similar to the Meta version. The current version might be the fastest one to run on NVIDIA GPUs.

@jploski
Contributor

jploski commented Aug 29, 2024

@jploski @ggerganov @compilade Hi guys, do you have any comments on this PR? I can't see any areas for optimization at the moment, and the unit tests are passing.

Using falcon-mamba-7b-instruct-F16 I see 18% performance improvement vs. my version (~12.3 tokens/sec vs. 14.5 tokens/sec). The output of each version is different (and also, in tests of the tinystories model, different from CPU after ~20 or so tokens), but coherent (haven't checked perplexity of both vs. Python-generated text yet).

What I'm wondering about is the following quote from the original Mamba paper:

"Concretely, instead of preparing the scan input (𝑨, 𝑩) of size (B, L, D, N) in GPU HBM (high-bandwidth memory), we load
the SSM parameters (Δ, 𝑨, 𝑩, 𝑪) directly from slow HBM to fast SRAM, perform the discretization and recurrence in SRAM,
and then write the final outputs of size (B, L, D) back to HBM."

Does the current implementation already take advantage of it? As far as I understand, the distinction between SRAM vs. HBM refers to utilizing shared memory in CUDA kernels. What I do not understand is whether the "fast" path in Python does anything explicit to manage this memory, or if it is already a side effect of the base algorithm.
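
In CUDA terms I would expect that to look roughly like the following (a generic sketch, not this PR's code): stage the parameters in __shared__ memory once per block and run the recurrence from there.

// Copy the per-block SSM parameters from global memory (HBM) into __shared__
// memory (SRAM) once, then the per-token recurrence reads only from shared memory
// and writes only the final outputs back to global memory.
__global__ void scan_smem_sketch(const float * __restrict__ A, float * __restrict__ state,
                                 const int d_state, const int n_tokens) {
    extern __shared__ float A_s[];              // sized at launch: d_state floats
    for (int i = threadIdx.x; i < d_state; i += blockDim.x) {
        A_s[i] = A[blockIdx.x*d_state + i];     // HBM -> SRAM, once per block
    }
    __syncthreads();
    // ... the recurrence over n_tokens steps would read A_s[...] instead of A[...],
    // updating `state` in registers/shared memory and writing results at the end
}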

One thing which I also haven't done yet is to measure the speed of generation of the "fast" Python implementation and compare to this PR.

I recall benchmarks showing that Mamba's generation is slower than transformer's for small context lengths, and only surpasses it after some threshold (because it's constant time per token, unlike in transformer). So as a smoke test, it may also be good to confirm that it holds for our implementation.

@piDack
Contributor Author

piDack commented Aug 29, 2024

Concretely, instead of preparing the scan input (𝑨, 𝑩) of size (B, L, D, N) in GPU HBM (high-bandwidth memory), we load
the SSM parameters (Δ, 𝑨, 𝑩, 𝑪) directly from slow HBM to fast SRAM, perform the discretization and recurrence in SRAM,
and then write the final outputs of size (B, L, D) back to HBM."

This point has not been utilized. The implementation in the paper employs a tiling algorithm similar to flash-attention, which is quite complex overall. I think it will be a substantial project and should be approached step by step.

@piDack
Contributor Author

piDack commented Aug 29, 2024

One thing which I also haven't done yet is to measure the speed of generation of the "fast" Python implementation and compare to this PR.

If it's the implementation from the paper, it should be faster because its operators are specifically optimized for CUDA. However, it has certain requirements on the CUDA version. As for minimamba, personally, I feel it should be faster than that. I will test.

Using falcon-mamba-7b-instruct-F16 I see 18% performance improvement vs. my version (~12.3 tokens/sec vs. 14.5 tokens/sec). The output of each version is different (and also, in tests of the tinystories model, different from CPU after ~20 or so tokens), but coherent (haven't checked perplexity of both vs. Python-generated text yet).

Fun! Is the output of your version the same as that of the CPU (tinystories model)?

@jploski
Contributor

jploski commented Aug 29, 2024

Using falcon-mamba-7b-instruct-F16 I see 18% performance improvement vs. my version (~12.3 tokens/sec vs. 14.5 tokens/sec). The output of each version is different (and also, in tests of the tinystories model, different from CPU after ~20 or so tokens), but coherent (haven't checked perplexity of both vs. Python-generated text yet).

Fun! Is the output of your version the same as that of the CPU (tinystories model)?

No, it also diverges after some initial tokens.

@piDack
Contributor Author

piDack commented Aug 29, 2024

correctness test

./build/bin/llama-batched -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -p "Hello, my name is" -np 16 -n 32

result

llama_model_loader: loaded meta data with 27 key-value pairs and 242 tensors from /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf (version GGUF V3 (latest))

main: n_predict = 32, n_ctx = 448, n_batch = 32, n_parallel = 16, n_kv_req = 437

Hello, my name is

main: generating 16 sequences ...

main: stream 0 finished at n_cur = 32
main: stream 1 finished at n_cur = 32
main: stream 2 finished at n_cur = 32
main: stream 3 finished at n_cur = 32
main: stream 4 finished at n_cur = 32
main: stream 5 finished at n_cur = 32
main: stream 6 finished at n_cur = 32
main: stream 7 finished at n_cur = 32
main: stream 8 finished at n_cur = 32
main: stream 9 finished at n_cur = 32
main: stream 10 finished at n_cur = 32
main: stream 11 finished at n_cur = 32
main: stream 12 finished at n_cur = 32
main: stream 13 finished at n_cur = 32
main: stream 14 finished at n_cur = 32
main: stream 15 finished at n_cur = 32

sequence 0:

Hello, my name is James.

I'm here to find out where to get a
 spot of a black-th/ and all the other details

sequence 1:

Hello, my name is James. I have a question about my home. I have a question about my home. I have a question about my home.


sequence 2:

Hello, my name is Misha, and I'm a newbie in the world of programming. I'm looking for a good and professional way to learn about

sequence 3:

Hello, my name is Niall and I'm a newbie. I've been trying to use the "Add to Favorites" feature in my

sequence 4:

Hello, my name is Amit and I am a big fan of your work. I am a student of yours and I am looking for a job. I

sequence 5:

Hello, my name is Ronaldo.

I am a young man who has been in the game for a long time. I am a big fan

sequence 6:

Hello, my name is Alex. I'm a writer and editor at the Daily Sun and I'm a member of the Writers Guild of America. I'm a

sequence 7:

Hello, my name is Alisa. I'm a student in the University of the West Indies. I'm looking for a job that I can be a part

sequence 8:

Hello, my name is Alexandra and I'm a photographer who lives in the UK. I'm looking to get back into photography and I'm looking for a

sequence 9:

Hello, my name is Tom. I'm a
famous musician, and I'm a writer for a newspaper. I'm a
professional musician and a

sequence 10:

Hello, my name is Gabe. I am a newbie to Android. I am a newbie to Android. I am new to Android. I am

sequence 11:

Hello, my name is K. I'm an old, but not really old.

I have no problem with this.

I have a problem

sequence 12:

Hello, my name is David.

I'm a very, uh, friendly, kindordial personordial.

I'm looking cryptocurrement,

sequence 13:

Hello, my name is Slim.

I'm just a little bit of a mystery.

I'm a little bit of a girl.


sequence 14:

Hello, my name is John.

I'm a little bit of a mess with the
liesossip.

I am a very, very,

sequence 15:

Hello, my name is K.
I am not

main: decoded 432 tokens in 0.95 s, speed: 455.42 t/s

llama_print_timings:        load time =    1836.86 ms
llama_print_timings:      sample time =      19.33 ms /   448 runs   (    0.04 ms per token, 23177.61 tokens per second)
llama_print_timings: prompt eval time =    2082.17 ms /   437 tokens (    4.76 ms per token,   209.88 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =    2784.86 ms /   438 tokens

GPU version perplexity test

./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99
perplexity: tokenizing the input ..
perplexity: tokenization took 1109.9 ms
perplexity: calculating perplexity over 560 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 0.53 seconds per pass - ETA 1.22 minutes
....
Final estimate: PPL = 22.5804 +/- 0.17868

CPU version perplexity test

Final estimate: PPL = 22.5804 +/- 0.17868

@piDack
Contributor Author

piDack commented Aug 29, 2024

Using falcon-mamba-7b-instruct-F16 I see 18% performance improvement vs. my version (~12.3 tokens/sec vs. 14.5 tokens/sec). The output of each version is different (and also, in tests of the tinystories model, different from CPU after ~20 or so tokens), but coherent (haven't checked perplexity of both vs. Python-generated text yet).

Fun! Is the output of your version the same as that of the CPU (tinystories model)?

No, it also diverges after some initial tokens.

I have done a perplexity test, and the PPL value appears to be the same as the CPU version. Therefore, I believe these ops are correct.

@jploski
Contributor

jploski commented Aug 29, 2024

Using falcon-mamba-7b-instruct-F16 I see 18% performance improvement vs. my version (~12.3 tokens/sec vs. 14.5 tokens/sec). The output of each version is different (and also, in tests of the tinystories model, different from CPU after ~20 or so tokens), but coherent (haven't checked perplexity of both vs. Python-generated text yet).

Fun! Is the output of your version the same as that of the CPU (tinystories model)?

No, it also diverges after some initial tokens.

I have done a perplexity test, and the PPL value appears to be the same as the CPU version. Therefore, I believe these ops are correct.

I found it a bit sus that your PPL values matched exactly, so I ran my own test (./llama-perplexity -m /mnt/f2fs/mamba/tinystories_f32.gguf -f scripts/wikitext-2-raw/wiki.test.raw -ngl ...). The results:

CPU: Final estimate: PPL = 957.3512 +/- 9.12936
GPU: Final estimate: PPL = 957.3612 +/- 9.12956

So yes, I think this is ready to be unleashed on the unsuspecting world, despite what our very friendly kindordial personordials might say...

@JohannesGaessler
Collaborator

For which batch sizes are these new kernels used? llama-perplexity by default uses a batch size of 512 so you may not notice any problems with e.g. kernels that are only used for batch size 1 (token generation). For those you should explicitly set a batch size. Also check batch sizes that are not powers of 2 for potential out-of-bounds accesses. Also note that if CUDA is available, it will be used even for CPU layers for batch sizes >= 32 so you have to compile completely without CUDA or any other GPU support for a proper comparison.

CPU: Final estimate: PPL = 957.3512 +/- 9.12936
GPU: Final estimate: PPL = 957.3612 +/- 9.12956

Unless the vocabulary size for the model is absolutely gargantuan those values look to me like a bug. A value < 10 should be expected for example for LLaMA 3. I never tested TinyStories but I can't imagine that it's that bad at predicting Wikitext tokens.

Final estimate: PPL = 22.5804 +/- 0.17868

That value also looks suspiciously large to me.

@jploski
Contributor

jploski commented Aug 29, 2024

For which batch sizes are these new kernels used? llama-perplexity by default uses a batch size of 512 so you may not notice any problems with e.g. kernels that are only used for batch size 1 (token generation). For those you should explicitly set a batch size. Also check batch sizes that are not powers of 2 for potential out-of-bounds accesses. Also note that if CUDA is available, it will be used even for CPU layers for batch sizes >= 32 so you have to compile completely without CUDA or any other GPU support for a proper comparison.

CPU: Final estimate: PPL = 957.3512 +/- 9.12936
GPU: Final estimate: PPL = 957.3612 +/- 9.12956

Unless the vocabulary size for the model is absolutely gargantuan those values look to me like a bug. A value < 10 should be expected for example for LLaMA 3. I never tested TinyStories but I can't imagine that it's that bad at predicting Wikitext tokens.

The tinystories model I'm using here is one I trained myself off mamba-130m-hf, which is a small model, and the TinyStories dataset is intentionally very restricted in content type. The outputs my model produces are pretty bad overall (possibly not just because it's a "silly" dataset, but also because the model is undertrained on it). So I wouldn't be paying much attention to the absolute value of PPL for wiki. The tokenizer vocabulary here is the same as mamba-130m-hf.

In that test I was mostly interested whether the CPU perplexity comes out exactly the same as GPU as in @piDack's experiment.

A value < 10 should be expected for example for LLaMA 3.

By LLaMA 3 are you referring to a big model? The Mamba models we both used for computing PPL here are both just 130M parameters, so the PPL=22 estimate of @piDack may be all right after all.

@piDack
Contributor Author

piDack commented Aug 30, 2024

For which batch sizes are these new kernels used? llama-perplexity by default uses a batch size of 512 so you may not notice any problems with e.g. kernels that are only used for batch size 1 (token generation). For those you should explicitly set a batch size. Also check batch sizes that are not powers of 2 for potential out-of-bounds accesses.

Once again, thanks for your code review. I tested the perplexity with different batch sizes via -ub (whose default value is 512), and the results were very similar. So, I believe the correctness can be assured.

GPU

# batch = 3
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 3
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 5
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 19
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 31
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99 -ub 32
./build/bin/llama-perplexity -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ngl 99
batch size   PPL
3            22.5772 +/- 0.17866
5            22.8276 +/- 0.18110
19           22.5804 +/- 0.17867
31           22.5804 +/- 0.17867
32           22.5805 +/- 0.1786
111          22.5804 +/- 0.17867
222          23.5505 +/- 0.18720
333          22.5823 +/- 0.17870
default      22.5804 +/- 0.17868

CPU

Recompile using OpenMP as the backend, and use nvidia-smi to ensure no GPU memory is used.

make clean
make -j llama-perplexity
./llama-perplexity  -m /mnt/yhl/model/mamba-130m/mamba-130M-F16.gguf -f /mnt/yhl/tmp/data/wikitext-2/wikitext-2-raw/wiki.test.raw -ub 31

Final estimate: PPL = 22.5590 +/- 0.17848

Meta

from #8546

PPL = 25.0894 +/- 0.1855

@mofosyne mofosyne added the Review Complexity : Medium (Generally require more time to grok but manageable by beginner to medium expertise level) label Aug 30, 2024
@A3shTnT
Contributor

A3shTnT commented Nov 27, 2024

For #6758. I have optimized the performance by 10x based on @jploski's PR https://github.com/jploski/llama.cpp/tree/falcon_mamba_cuda. Under the current test case, the performance of conv has improved from 0.62 GB/s to 6.47 GB/s, and the performance of scan has increased from 0.30 GB/s to 2.86 GB/s.

Hello, I am wondering how to test the speed. What tools do you use to get the value 2.86 GB/s?

@jploski
Contributor

jploski commented Nov 27, 2024

For #6758. I have optimized the performance by 10x based on @jploski's PR https://github.com/jploski/llama.cpp/tree/falcon_mamba_cuda. Under the current test case, the performance of conv has improved from 0.62 GB/s to 6.47 GB/s, and the performance of scan has increased from 0.30 GB/s to 2.86 GB/s.

Hello, I am wondering how to test the speed. What tools do you use to get the value 2.86 GB/s?

I can't comment on that particular metric.

Generally you can profile individual kernels with NVIDIA Nsight Compute (the command ncu or ncu-ui). You can simply let this tool run the usual llama.cpp command line to generate tokens while specifying upfront the kernel you want to collect performance data about (the "Filter" tab in ncu-ui). However, note that for optimizations to make practical sense you should also keep an eye on how much that kernel's execution contributes to the overall execution time, (in particular the per-token memory transfers between GPU and host might dominate the execution time).
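
A possible command line (illustrative; the kernel-name filter, model path and output name are placeholders):

ncu --kernel-name regex:ssm --launch-count 10 -o ssm_profile \
    ./llama-cli -m model.gguf -p "Hello" -n 32 -ngl 99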

@A3shTnT
Contributor

A3shTnT commented Nov 27, 2024

For #6758. I have optimized the performance by 10x based on @jploski's PR https://github.com/jploski/llama.cpp/tree/falcon_mamba_cuda. Under the current test case, the performance of conv has improved from 0.62 GB/s to 6.47 GB/s, and the performance of scan has increased from 0.30 GB/s to 2.86 GB/s.

Hello, I am wondering how to test the speed. What tools do you use to get the value 2.86 GB/s?

I can't comment on that particular metric.

Generally you can profile individual kernels with NVIDIA Nsight Compute (the command ncu or ncu-ui). You can simply let this tool run the usual llama.cpp command line to generate tokens while specifying upfront the kernel you want to collect performance data about (the "Filter" tab in ncu-ui). However, note that for optimizations to make practical sense you should also keep an eye on how much that kernel's execution contributes to the overall execution time, (in particular the per-token memory transfers between GPU and host might dominate the execution time).

Got it, thank you, I will give it a try.

@A3shTnT A3shTnT mentioned this pull request Nov 28, 2024
@jploski
Contributor

jploski commented Dec 2, 2024

I think this PR has been superseded by PR #10558.

@piDack piDack closed this Dec 2, 2024