[backend](cuda): faster uncontiguous concat #10760

Merged
merged 4 commits into ggerganov:master on Dec 12, 2024

Conversation

@A3shTnT (Contributor) commented Dec 10, 2024

Faster implementation of uncontiguous concat on CUDA.

  1. The Debug and Release build CI tests pass.
  2. test-backend-ops passes. I also added one performance test, but since it may not be a realistic example, I only post its result in the experiment below.
  3. The first uncontiguous concat kernel in deepseek-v2 is faster than on master. However, the overall performance is lower than master, which confuses me a little. Given the run-to-run variation in llama-bench results, anyone is welcome to run the test and discuss the results.

Performance experiment:
test-backend-ops:

test case:
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {512, 1024, 5, 1}, 1024, 0, 1));
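
A run like the ones below can then be reproduced with something along these lines (the exact binary path depends on the build directory; -o restricts test-backend-ops to a single op):

./build/bin/test-backend-ops perf -o CONCAT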

my implementation:

Backend 1/2: CUDA0
  Device description: NVIDIA GeForce RTX 3060 Laptop GPU
  Device memory: 6143 MB (5122 MB free)

  CONCAT(type=f32,ne_a=[512,1024,5,1],ne_b_d=1024,dim=0,v=1):                   5004 runs -   211.03 us/run -   120830 kB/run -  546.04 GB/s
  Backend CUDA0: OK

master:

  Device 0: NVIDIA GeForce RTX 3060 Laptop GPU, compute capability 8.6, VMM: yes
Testing 2 devices

Backend 1/2: CUDA0
  Device description: NVIDIA GeForce RTX 3060 Laptop GPU
  Device memory: 6143 MB (5122 MB free)

  CONCAT(type=f32,ne_a=[512,1024,5,1],ne_b_d=1024,dim=0,v=1):                   3336 runs -   322.92 us/run -   120830 kB/run -  356.84 GB/s
  Backend CUDA0: OK

llama-bench:

command:
~/program/forked/llama.cpp/build/bin$ ./llama-bench -m ~/program/forked/DeepSeek-V2-Lite-Chat.IQ1_S.gguf

the first uncontiguous concat kernel:
my implementation:
[screenshot: Nsight Compute profile of the new concat kernel]

master:
[screenshot: Nsight Compute profile of the master concat kernel]

the overall result:
my implementation:

| model                            |     size |   params | backend | ngl | test  |           t/s |
| -------------------------------- | -------: | -------: | ------- | --: | ----- | ------------: |
| deepseek2 16B IQ1_S - 1.5625 bpw | 4.65 GiB |  15.71 B | CUDA    |  99 | pp512 | 768.63 ± 8.81 |
| deepseek2 16B IQ1_S - 1.5625 bpw | 4.65 GiB |  15.71 B | CUDA    |  99 | tg128 |  21.84 ± 3.75 |

master:

| model                            |     size |   params | backend | ngl | test  |           t/s |
| -------------------------------- | -------: | -------: | ------- | --: | ----- | ------------: |
| deepseek2 16B IQ1_S - 1.5625 bpw | 4.65 GiB |  15.71 B | CUDA    |  99 | pp512 | 804.11 ± 4.73 |
| deepseek2 16B IQ1_S - 1.5625 bpw | 4.65 GiB |  15.71 B | CUDA    |  99 | tg128 |  20.92 ± 3.66 |

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Dec 10, 2024
@JohannesGaessler (Collaborator) left a comment

Check NVIDIA NSight Systems and look at the runtime for different kernel configurations. It may be that the variant you looked at with NSight Compute is faster but another one has become slower (though I wouldn't know why).

Comment on lines 220 to 221
default:
break;
Collaborator

Suggested change:

-            default:
-                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;

if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
} else {
x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
if /*constexpr*/ (dim == 0) {
@JohannesGaessler (Collaborator) commented Dec 10, 2024

Please add a static_assert for dim.
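
A minimal sketch of what that could look like, assuming dim is a non-type template parameter of the kernel (the name and signature here are illustrative, not the PR's actual code):

template <int dim>
static __global__ void concat_f32_non_cont_sketch(const char * src0, const char * src1, char * dst) {
    // Reject invalid concat dimensions at compile time instead of relying on
    // a runtime check (or silently doing the wrong thing).
    static_assert(dim >= 0 && dim <= 3, "concat dim must be 0, 1, 2 or 3");
    // ... index math that can treat dim as a compile-time constant ...
    (void) src0; (void) src1; (void) dst;
}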

Collaborator

Also note that if constexpr should be available now that we are using C++17, no need to keep it commented.

@A3shTnT (Contributor, Author) commented Dec 11, 2024

Also note that if constexpr should be available now that we are using C++17, no need to keep it commented.

Thank you for your advice, I removed the comment markers. I also found that the number and content of the instructions generated by the compiler did not change with or without constexpr; perhaps nvcc already optimizes this case on its own.
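
That is expected in many cases: when the condition is a compile-time constant such as a template parameter, the compiler can usually fold the branch even without if constexpr; if constexpr simply guarantees it and keeps the discarded branch from being instantiated. A small illustration (hypothetical names, not the PR's kernel):

template <int dim>
__device__ float pick_source(const float * a, const float * b, int i) {
    if constexpr (dim == 0) {
        // Only this branch is instantiated for dim == 0 ...
        return a[i];
    } else {
        // ... and only this one for every other dim.
        return b[i];
    }
}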

@A3shTnT (Contributor, Author) commented Dec 11, 2024

Check NVIDIA NSight Systems and look at the runtime for different kernel configurations. It may be that the variant you looked at with NSight Compute is faster but another one has become slower (though I wouldn't know why).

I'll give it a try.

@A3shTnT (Contributor, Author) commented Dec 12, 2024

Check NVIDIA NSight Systems and look at the runtime for different kernel configurations. It may be that the variant you looked at with NSight Compute is faster but another one has become slower (though I wouldn't know why).

I'll give it a try.

Since this computer cannot run a model as large as DeepSeek, I used an existing Mamba model, but the problem should be quite similar.
It seems that the uncontiguous concat operation accounts for a relatively small share of the total execution time of the model, so the performance improvement is hard to see. Updates targeting small operators may not achieve significant acceleration at the model level, and fluctuations in runtime-function execution time may have an even greater impact. I am not familiar with the runtime, so this is just my opinion.

Kernel execution

my implementation:

Time (%)   Total Time (ns)   Instances    Avg (ns)    Med (ns)   Min (ns)   Max (ns)   StdDev (ns)   Name
    31.4        1508066774       16031     94071.9     59026.0      57390     955443     171105.2    void mul_mat_vec
     ...
     1.1          51516703       31056      1658.8       963.0        898      74872       7035.6    silu_f32
     ...
     0.4          20815533         144    144552.3    139977.0     136432     206558      16592.1    void concat_f32_non_cont
     ...

PR #10558 (this can be considered equivalent to master, because its changes are not related to concat):

Time (%)   Total Time (ns)   Instances    Avg (ns)    Med (ns)   Min (ns)   Max (ns)   StdDev (ns)   Name
    31.2        1503470079       16031     93785.2     58678.0      57521     956623     170486.9    void mul_mat_vec
     ...
     1.2          55916400         144    388308.3    387963.0     380733     399661       3754.4    concat_f32_non_cont
     1.1          51843144       31056      1669.3       965.0        899      74970       7049.0    silu_f32
     ...
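
As a rough sanity check from these two tables: the non-contiguous concat's total kernel time drops from about 55.9 ms (1.2 % of kernel time) on master to about 20.8 ms (0.4 %) with this PR, i.e. roughly a 2.7x faster kernel but only about 0.8 % of total kernel time saved, while mul_mat_vec alone accounts for about 1.5 s. So the end-to-end effect is smaller than the run-to-run variation seen in llama-bench.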

Runtime function
my implementation:

Time (%)   Total Time (ns)   Num Calls      Avg (ns)    Med (ns)   Min (ns)    Max (ns)    StdDev (ns)   Name
    38.2        3830368168        6718      570165.0     17748.5        397   324782389     8298624.3    cudaStreamSynchronize
    25.2        2526634352        2027     1246489.6     15925.0       1772   294406598     8298396.3    cuLibraryLoadData
    25.0        2507314668      360955        6946.3      5192.0       2242     3443261       15761.5    cudaLaunchKernel
     6.2         617263233           3   205754411.0    480931.0     180886   616601416   355803975.0    cudaMemGetInfo
     4.6         462075207       34389       13436.7      6838.0       3076     4687424      264276.4    cudaMemcpyAsync

PR #10558:

Time (%)   Total Time (ns)   Num Calls      Avg (ns)    Med (ns)   Min (ns)    Max (ns)    StdDev (ns)   Name
    43.6        3795485106        6718      564972.5     15840.5        342   313831670     8389839.7    cudaStreamSynchronize
    32.3        2811759210      360955        7789.8      4643.0       1991     4257846       18166.3    cudaLaunchKernel
    14.8        1284635317        2027      633761.9      5042.0       1289   203751445     5059546.5    cuLibraryLoadData
     5.3         465147879       34389       13526.1      5898.0       2665    22423451      136640.3    cudaMemcpyAsync
     3.1         271980756           3    90660252.0    453325.0     181184   271346247   156478720.9    cudaMemGetInfo

@JohannesGaessler (Collaborator) left a comment

Since this computer cannot run a model as large as Deepseek, I used the existing MAMBA model, but the problem should be quite similar.

You can profile the code from the command line with something like nsys profile ./llama-bench, then copy the resulting file to your desktop where you can analyze it. But these numbers already look very good so I would be fine with merging this as-is.

@A3shTnT (Contributor, Author) commented Dec 12, 2024

Since this computer cannot run a model as large as Deepseek, I used the existing MAMBA model, but the problem should be quite similar.

You can profile the code from the command line with something like nsys profile ./llama-bench, then copy the resulting file to your desktop where you can analyze it. But these numbers already look very good so I would be fine with merging this as-is.

The profiling information was obtained with Nsight Systems: nsys profile to record the run, and then nsys stats to produce the summary tables. Nsight Systems is indeed a great tool for analyzing the whole model at once. Thank you for your advice.
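
For completeness, that workflow looks roughly like this (report name and model path are placeholders):

nsys profile -o concat_report ./llama-bench -m /path/to/model.gguf
nsys stats concat_report.nsys-rep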

@JohannesGaessler JohannesGaessler merged commit 8faa1d4 into ggerganov:master Dec 12, 2024
47 checks passed
netrunnereve pushed a commit to netrunnereve/llama.cpp that referenced this pull request Dec 16, 2024
* faster uncontiguous concat

* Use a lambda to avoid code duplication

Co-authored-by: Diego Devesa <[email protected]>

* Update ggml/src/ggml-cuda/concat.cu

* add constexpr  and static assert

---------

Co-authored-by: Diego Devesa <[email protected]>