[FA2] tiling-qkv F32/F16 + swizzle q/qk/qkv🎉 (#213)
* Update flash_attn_mma_tiling_qkv.cu

* Update flash_attn_mma_tiling_qkv_F32F16F16F32.cu

* Update flash_attn_mma_tiling_qkv_swizzle_q.cu

* Update flash_attn.cc

* Update flash_attn_mma.py

* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu

* Update flash_attn.cc

* Update flash_attn_mma.py

* Update flash_attn_mma_tiling_qkv_swizzle_qkv.cu

* Update flash_attn_mma_tiling_qkv_swizzle_q.cu

* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu

* Update flash_attn_mma_tiling_qkv_swizzle_q.cu

* Update flash_attn_mma_tiling_qkv_swizzle_qkv.cu

* Update flash_attn.cc

* Update flash_attn_mma.py

* Update flash_attn_mma_share_kv.cu

* Update flash_attn_mma_share_kv_F32F16F16F32.cu

* Update flash_attn_mma_split_kv.cu

* Update flash_attn_mma_split_q.cu

* Update flash_attn_mma_tiling_qk.cu

* Update flash_attn_mma_tiling_qk_F32F16F16F32.cu

* Update flash_attn_mma_tiling_qkv.cu

* Update flash_attn_mma_tiling_qkv_F32F16F16F32.cu

* Create flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu

* Create flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu

* Create flash_attn_mma_tiling_qkv_swizzle_qkv_F32F16F16F32.cu

* Update README.md

* Update README.md

* Update flash_attn_mma.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update flash_attn_mma_tiling_qkv_swizzle_q.cu

* Update flash_attn_mma_tiling_qkv_swizzle_q.cu

* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu

* Update flash_attn_mma_tiling_qkv_swizzle_qk.cu

* Update flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu

* Update flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu

* Update flash_attn.cc

* Update flash_attn_mma.py

* Update flash_attn_mma_tiling_qkv_F32F16F16F32.cu

* Update flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu

* Update flash_attn_mma_tiling_qkv_swizzle_qkv_F32F16F16F32.cu

* Update flash_attn_mma.py

* Update flash_attn.cc

* Update README.md

* Update README.md
DefTruth authored Jan 3, 2025
1 parent c687988 commit 82f1d04
Showing 18 changed files with 6,509 additions and 156 deletions.
17 changes: 10 additions & 7 deletions README.md
@@ -55,16 +55,16 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QKV Fine-grained Tiling**|
|✔️|✔️|✔️|✔️|
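
For context, the **Prefetch K/V g2s** entry above refers to issuing the global-to-shared copy of the next K/V tile asynchronously while the current tile is being consumed. A minimal, hedged sketch of that pattern using CUDA's pipeline primitives follows; the names, tile sizes, and indexing here are illustrative assumptions, not the repo's actual kernel code.

```cuda
// Illustrative sketch only (assumed names/sizes, not the repo's kernels):
// copy one 16-byte chunk of K from global to shared memory asynchronously,
// overlap it with compute on the current tile, then wait before use.
#include <cuda_fp16.h>
#include <cuda_pipeline.h>

__global__ void prefetch_k_g2s_demo(const half* __restrict__ K, int n_halves) {
  extern __shared__ half smem_k[];                 // one 16B chunk per thread
  const int gidx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;  // 8 halves = 16B
  if (gidx + 8 <= n_halves) {
    // Issue the async global->shared copy for the next tile ...
    __pipeline_memcpy_async(&smem_k[threadIdx.x * 8], &K[gidx], 16);
    __pipeline_commit();
    // ... MMA work on the *current* tile would be overlapped here ...
    __pipeline_wait_prior(0);                      // next tile is now resident
  }
  __syncthreads();
}
```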

Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, the [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, the [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **92 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (MMA Acc F16/F32, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)

|Algorithm| (B,H,N,D) | RTX 3080 Laptop | L20 | RTX 4090 |
|:---:|:---:|:---:|:---:|:---:|
|FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
|split-q+share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
|share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
|FlashAttention-2|(1,48,8192,64)|37 TFLOPS|109 TFLOPS|163 TFLOPS|
|split-q+share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
|share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
|SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
|split-q+tiling-qkv+stage2|(1,48,8192,512)|**23 TFLOPS**|**90 TFLOPS**|**135 TFLOPS**|
|tiling-qkv+swizzle-qk+stage2|(1,48,8192,512)|**23 TFLOPS**|**92 TFLOPS**|**157 TFLOPS**|
|Precision Errors vs FA2/SDPA| / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |
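
As a sanity check on how the TFLOPS figures in this table are typically derived, the host-only sketch below applies the common `4*B*H*N*N*D` attention FLOP count to the (1,48,8192,512) case; the 72 ms runtime is a hypothetical value chosen only to illustrate the arithmetic, and the repo's benchmark script may count FLOPs differently.

```cuda
// Host-only sketch (plain C++ in a .cu file): back-of-the-envelope TFLOPS
// math for the (1,48,8192,512) row, assuming the common attention FLOP
// count of 4*B*H*N*N*D (QK^T plus P*V, 2 FLOPs per multiply-add).
#include <cstdio>

int main() {
  const double B = 1, H = 48, N = 8192, D = 512;
  const double flops      = 4.0 * B * H * N * N * D;   // ~6.6e12 FLOPs
  const double runtime_ms = 72.0;                       // hypothetical runtime
  printf("TFLOPS = %.1f\n", flops / (runtime_ms * 1e-3) / 1e12);  // ~91.6
  return 0;
}
```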

The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K and V across MMA (Warps), is slower than the `Split Q` method, which splits only Q across MMA (Warps) while keeping K and V accessible to all MMA (Warps), as illustrated in the sketch below.
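
To make the distinction concrete, the following sketch shows only the warp-to-work mapping implied by the two schemes (assumed tile sizes `Br=Bc=64` and 4 warps per block; not the repo's actual kernels): in Split Q each warp owns a disjoint slice of Q rows and walks all K/V tiles independently, while in Split KV warps share the same Q rows but split the K/V columns and must later merge their partial softmax/output state.

```cuda
// Illustrative warp-mapping demo only (assumed tiling, not the repo's code).
#include <cstdio>

__global__ void warp_mapping_demo(int seqlen, int Br, int Bc) {
  const int warp_id = threadIdx.x / 32;
  if (threadIdx.x % 32 != 0) return;  // one lane reports per warp

  // Split Q: each warp owns a distinct slice of this block's Br Q rows and
  // loops over *all* K/V tiles; no cross-warp merge of partial results.
  const int q_rows_per_warp = Br / 4;
  const int q_row_begin = blockIdx.x * Br + warp_id * q_rows_per_warp;
  printf("SplitQ : warp %d -> Q rows [%d, %d), all %d KV tiles\n",
         warp_id, q_row_begin, q_row_begin + q_rows_per_warp,
         (seqlen + Bc - 1) / Bc);

  // Split KV: all warps share the same Q rows but each takes a different
  // K/V column slice, so partial (m, l, O) state must later be merged.
  const int kv_cols_per_warp = Bc / 4;
  const int kv_col_begin = warp_id * kv_cols_per_warp;
  printf("SplitKV: warp %d -> Q rows [%d, %d), KV cols [%d, %d) per tile\n",
         warp_id, blockIdx.x * Br, blockIdx.x * Br + Br,
         kv_col_begin, kv_col_begin + kv_cols_per_warp);
}

int main() {
  warp_mapping_demo<<<1, 128>>>(8192, 64, 64);
  cudaDeviceSynchronize();
  return 0;
}
```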
@@ -369,9 +369,12 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [flash_attn_mma...tiling_qk_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qk_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qk_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma...tiling_qkv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma...tiling_qkv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma...tiling_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qkv_swizzle{q}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qkv_swizzle{qk}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma...tiling_qkv_swizzle{qkv}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn...tiling_qkv_swizzle{q}{f32}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn...tiling_qkv_swizzle{qk}{f32}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn...tiling_qkv_swizzle{qkv}{f32}*](./kernels/flash-attn/mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv_F32F16F16F32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|

**rr**: means reduced register usage (for `d>128`); **f32**: means MMA accumulation with the FP32 dtype, otherwise FP16. The softmax Acc dtype is always FP32 for high precision; **swizzle**: currently, only smem swizzle for MMA is supported.
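
For the **swizzle** variants, the core trick is a permuted shared-memory layout. Below is a minimal, hedged sketch of the usual XOR-based mapping (assuming 16-byte vectors and 128-byte smem rows; the actual kernels may use a different swizzle function): XOR-ing the vector index with the low bits of the row index lets column-wise `ldmatrix`-style accesses land in distinct banks instead of conflicting.

```cuda
// Minimal XOR-swizzle sketch (host-runnable; not the repo's exact layout).
// With 128-byte smem rows and 16-byte vectors there are 8 vectors per row;
// permuting the vector index per row avoids 8-way bank conflicts when a
// warp reads the same logical column across consecutive rows.
#include <cstdio>

__host__ __device__ int swizzle_vec(int row, int vec) {
  return vec ^ (row & 7);   // 8 x 16B vectors per 128B row
}

int main() {
  // Print the permuted layout for the first 8 rows: each logical column
  // lands in a different physical vector slot per row.
  for (int r = 0; r < 8; ++r) {
    for (int v = 0; v < 8; ++v) printf("%d ", swizzle_vec(r, v));
    printf("\n");
  }
  return 0;
}
```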

10 changes: 5 additions & 5 deletions kernels/flash-attn/README.md
@@ -12,19 +12,19 @@
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
|✔️|✔️|✔️|✔️|

This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (MMA Acc F16, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)
This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192, D<=64)` it can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (MMA Acc F16/F32, softmax Acc F32 vs FA2 MMA/softmax Acc F32, 👇Benchmark)

|Algorithm| (B,H,N,D) | NVIDIA RTX 3080 Laptop | NVIDIA L20 | NVIDIA GeForce RTX 4090 |
|:---:|:---:|:---:|:---:|:---:|
|FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
|split-q+share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
|share-qkv+stage2|(1,8,8192,64)|**55 TFLOPS**|99 TFLOPS|**221 TFLOPS**|
|FlashAttention-2|(1,48,8192,64)|37 TFLOPS|109 TFLOPS|163 TFLOPS|
|split-q+share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
|share-qkv+stage2|(1,48,8192,64)|**48 TFLOPS**|107 TFLOPS|**224 TFLOPS**|
|SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
|split-q+tiling-qkv+stage2|(1,48,8192,512)|**23 TFLOPS**|**90 TFLOPS**|**135 TFLOPS**|
|tiling-qkv+swizzle-qk+stage2|(1,48,8192,512)|**23 TFLOPS**|**92 TFLOPS**|**157 TFLOPS**|
|Precision Errors vs FA2/SDPA| / | max: < ~1e-3 | min: ~0.0 | mean: < ~1e-5 |

For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, the [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **90 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~
For example, on NVIDIA RTX 3080 Laptop, the [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) method can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, the [📚 Split Q + Fully QKV Fine-grained Tiling](#mma-tiling-qkv) method can achieve **92 TFLOPS (D=512)**, almost **~1.6x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~

## 📖 Contents

