Performance Tuning for Q4_K matmul CUDA kernel #8136
base: master
Conversation
For this format I think mmvq should be automatically used for compute capability >= 6.1.
Similar results for RTX 2060:
LLAMA_CUDA_FORCE_DMMV=1 LLAMA_CUDA=1 ./scripts/compare-commits.sh master pr/8136 -m ./models/llama-7b-v2/ggml-model-q4_k.gguf -p 0 -ngl 99

LLAMA_CUDA=1 make -j && ./llama-bench -m ./models/llama-7b-v2/ggml-model-q4_k.gguf -p 0 -ngl 99
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
build: 50732c0 (3236)
I don't think that's the reason. In any case, I can confirm that the performance increases with this change and that most of the difference comes from the
#define BLOCK_DIM_X 32
#define BLOCK_DIM_Y 4
Suggested change:
#define Q4_K_BLOCK_DIM_X 32
#define Q4_K_BLOCK_DIM_Y 4
Also if the optimal value for the x dimension is 32, should that parameter simply be removed?
Any is fine with me - I had them there as defines to easily change them and also reuse them across other kernels in my testing.
I left them in this PR for readability, but I'm happy to hard-code them or rename them.
In that case I would just hard-code them because as it is they get multiplied with magic numbers anyways.
Unless I am completely misremembering, the CUDA documentation definitely said that the
Assuming this can also be applied to other kernels, this is definitely a very good find.
@JohannesGaessler float4 issues a single instruction to read 128 bits. In theory, the compiler should be capable of understanding this from the original code (without having to use "unsafe" methods such as reinterpret_cast), but the compiler is often imperfect. @slaren My plan was to look at
You are correct. A float4 is 4 floats in 4 registers, and there is support for aligned vectorized loads (LD*.(8|16|32|64|128), with the numbers being bits). Loads / stores are split into transactions of sectors (32 bytes), with one sector being processed each clock. If a warp accesses multiple sectors, they get serialized, which increases MIO utilization. Simplifying things a bit, memory efficiency can be considered as the fraction of transferred bytes that are actually used. In this case 4 contiguous floats were read with 4 LD instructions. Two contiguous threads have a byte offset of 16 bytes, resulting in reading bytes 0-3 and 16-19 within a sector. Thus memory efficiency was 8/32 = 0.25. Using float4 instead of float increased this metric to 1.0.
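To make the contrast concrete, here is a minimal sketch added for illustration (the toy kernel and names are mine, not the PR's Q4_K code): the scalar path issues four separate 32-bit loads, while the float4 path lets the compiler emit a single 128-bit load, provided the address is 16-byte aligned.

#include <cuda_runtime.h>

// Illustration only: four scalar 32-bit loads vs. one vectorized 128-bit load.
__global__ void load_demo(const float * y, float * out) {
    // Scalar path: four separate 32-bit loads of y[0..3].
    const float s_scalar = y[0] + y[1] + y[2] + y[3];

    // Vectorized path: a single 128-bit load of y[4..7] via float4.
    // Requires a 16-byte aligned address; y + 4 is aligned if y itself is
    // (cudaMalloc allocations are aligned to at least 256 bytes).
    const float4 v = *reinterpret_cast<const float4 *>(y + 4);
    const float s_vec = v.x + v.y + v.z + v.w;

    out[0] = s_scalar;
    out[1] = s_vec;
}

A float4 load from an address that is not 16-byte aligned is a misaligned-address error on the GPU, which is why the unaligned-load discussion further down matters.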
A full sweep with Nsight Compute through all kernels / kernel dimensions of the most important networks will reveal candidates where the optimization might help. FB utilization is the metric to look for most of the time for the llama.cpp kernels. A utilization >95% is perfect, >90% is pretty good; everything below might be using memory inefficiently and could thus benefit from vectorized loads and/or reordering of memory accesses.
Yes, mmvq is used more frequently. At this point, the dmmv kernels are only used with GPUs that don't support dp4a (cc < 6.1).
All NVIDIA GPUs starting with compute capability 6.1 have the __dp4a instruction.
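For context, a minimal sketch of what __dp4a computes (the helper and its fallback branch are my own illustration, not code from this PR): it accumulates the dot product of four packed signed 8-bit values into a 32-bit integer, which is what the quantized dot-product kernels rely on and why they need cc >= 6.1.

#include <cstdint>

// Illustration only: __dp4a(a, b, c) treats a and b as four packed signed
// 8-bit lanes, multiplies them lane-wise, sums the products and adds c.
__device__ int dot4_i8(const int a, const int b, const int c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 610
    return __dp4a(a, b, c);
#else
    // Manual fallback for GPUs without dp4a (cc < 6.1).
    const int8_t * pa = reinterpret_cast<const int8_t *>(&a);
    const int8_t * pb = reinterpret_cast<const int8_t *>(&b);
    return c + pa[0]*pb[0] + pa[1]*pb[1] + pa[2]*pb[2] + pa[3]*pb[3];
#endif
}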
Basically the only kernels that are performance relevant are

For MMVQ I've been thinking it would maybe be possible to just load the data contiguously and apply a bit mask, instead of loading the data from 2/4 byte values that are not 100% contiguous.
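As a rough sketch of that bit-mask idea (my own illustration, not the actual Q4_K data layout or code from this PR): with 4-bit quants packed two per byte, a 32-bit word can be loaded contiguously and the low/high nibbles of its four bytes extracted with two masks, instead of gathering smaller values that are not fully contiguous.

// Illustration only: unpack eight 4-bit values from one contiguously loaded
// 32-bit word (four low nibbles into lo, four high nibbles into hi).
__device__ __forceinline__ void unpack_q4x8(const int packed, int & lo, int & hi) {
    lo = packed & 0x0F0F0F0F;         // low nibble of each byte
    hi = (packed >> 4) & 0x0F0F0F0F;  // high nibble of each byte
}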
I currently don't have access to an instance of Nsight Compute. MMVQ had ~90% "speed of light" memory utilization; I don't know FB specifically.
This is not really the design; backends have the ability to change the layout of the tensor data. I can go into more detail if you think that could improve performance significantly, but essentially the backends can convert the data to whatever layout they want during the
I previously made prototypes where I converted the data to a struct-of-arrays layout, but I was not able to get better performance for MMVQ (though it's always possible that I just did it wrong). For MMQ it would maybe be useful, because for efficient use of asynchronous data loading you need 16 byte alignment. But right now the int8 tensor core utilization is only 35%, so there are probably still other problems that would need to be fixed. And unless you drop support for partial offloading completely, you would need to implement and compile two separate instances for loading data per quantization format, so I'm thinking that changing the data layout is comparatively a lot of work for the potential benefits.
I don't think this needs to be the case; during partial offloading the weights are also copied to VRAM using the backend interface, and they can be converted to a different layout during the copy. As long as it is only a change in layout and does not require any expensive computations, I don't think it would affect performance significantly.
In my prototype I did the conversion via host->device
Given that SM load is low for the kernels of interest, instead of adding complexity to the codebase we can also add support for unaligned loads in CUDA. The approach is as good as it can get for loading 32 bits from an unaligned address: given that the offset is 2-byte aligned, there is a 50% chance that only a single load is required, and if the alignment is not given, 2 loads have to be done anyway and the results have to be combined. https://godbolt.org/z/6zdT9osb3 has the code and SASS. The unaligned load logic essentially adds only 2 LOP3 instructions, and that's it. When streaming contiguous data, only a single additional load would be required for the whole stream.
The same can be done with 64-bit and 128-bit loads, as long as there are free cycles in the ALU or FP unit to move registers around.
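A minimal sketch of what such an unaligned 32-bit load can look like (my own illustration, not the code behind the Godbolt link), written with the __byte_perm intrinsic rather than inline PTX:

#include <cstdint>

// Illustration only: read a 32-bit value from a possibly unaligned address by
// loading the aligned word(s) that cover it and merging bytes with __byte_perm.
// Assumes the surrounding allocation is at least 4-byte aligned, so the
// rounded-down word is still inside the allocation.
__device__ uint32_t load_u32_unaligned(const void * ptr) {
    const uintptr_t  addr   = (uintptr_t) ptr;
    const uint32_t * base   = (const uint32_t *) (addr & ~(uintptr_t) 3); // round down to 4 bytes
    const uint32_t   offset = (uint32_t) (addr & 3);                      // 0..3

    const uint32_t lo = base[0];
    if (offset == 0) {
        return lo; // aligned: a single load suffices
    }
    const uint32_t hi = base[1]; // second word only needed in the misaligned case
    // Selector 0x3210 returns lo unchanged; adding 0x1111 per byte of offset
    // shifts the 4-byte window across the {hi,lo} pair.
    return __byte_perm(lo, hi, 0x3210 + offset * 0x1111);
}

With a 2-byte-aligned offset as described above, offset is either 0 or 2, which matches the 50% single-load case.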
Relative to the current master code that loads the data as 2 16-bit values, I'm measuring a speedup of ~1%. Is there a reason why you're using inline PTX assembly instead of __byte_perm (which I think does the same thing)?
float4 y11 = *reinterpret_cast<const float4*>(y1+0);
float4 y12 = *reinterpret_cast<const float4*>(y1+32);
float4 y21 = *reinterpret_cast<const float4*>(y2+0);
float4 y22 = *reinterpret_cast<const float4*>(y2+32);
I think these could be made const, but it doesn't really matter.
Changes:
- Load the y values with vectorized float4 loads instead of scalar float loads
- Use 128 threads per block (BLOCK_DIM_X 32, BLOCK_DIM_Y 4)
Performance
Using llama-bench I measured the end-to-end speedup.
Device 0: NVIDIA RTX 6000 Ada Generation, compute capability 8.9, VMM: yes

Measured configurations:
- BASELINE
- FLOAT4
- 128 Threads per Block
- FLOAT4 + 128 Threads per Block