Performance update on the backward split kernel #127

Open · wants to merge 24 commits into main_perf

Changes from all commits

90 changes: 90 additions & 0 deletions .github/workflows/amd_tests.yml
@@ -0,0 +1,90 @@
name: AMD Perf Kernel Tests

on:
workflow_dispatch:
pull_request:
branches: [main_perf]

concurrency:
group: ${{ github.ref }}
cancel-in-progress: true

permissions: read-all

jobs:
Integration-Tests-AMD:
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: [linux-mi300-gpu-1, gfx1100]
fail-fast: false # disables failing the entire job when one matrix entry fails
container:
image: rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Show Device Info
run: |
rocminfo | grep gfx

- name: Uninstall Triton
run: |
pip uninstall -y triton
rm -rf ~/.triton
rm -rf ./triton/python/build

- name: Install Triton
run: |
git clone https://github.com/triton-lang/triton
cd triton
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
pip install ninja cmake wheel pybind11 # build-time dependencies
pip install matplotlib pandas pytest # triton bench dependencies
pip install --verbose --no-build-isolation ./python
cd ..

- name: Show Triton version
run: |
pip show triton

- name: Build
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python setup.py install

- name: Flash Attention Tests using PyTorch reference implementation
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
FLASH_ATTENTION_TRITON_AMD_REF=1 pytest tests/test_flash_attn_triton_amd.py

# CDNA Tests
- name: Flash Attention CDNA Tests
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py

# FIXME: run the full suite
- name: AMD Tests
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest -v -s flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_fp8 flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_varlen_fp8

- name: AMD Bench
if: matrix.runner == 'linux-mi300-gpu-1'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 python flash_attn/flash_attn_triton_amd/bench.py

# RDNA Tests
- name: Flash Attention RDNA Tests
if: matrix.runner == 'gfx1100'
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

# NOTE: this exceeds 6 hrs on "gfx1100" so we are testing a subset of the tests. The full suite is run on a CDNA machine.
pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output
15 changes: 14 additions & 1 deletion .gitignore
@@ -24,4 +24,17 @@ var/
.idea/

# Dev
venv
venv

# AMD
scripts
csrc/flash_attn_ck
.eggs
*.log
core.*
gpucore.*
*.csv
*.png
*.html
*.json
*.txt
65 changes: 63 additions & 2 deletions README.md
@@ -112,7 +112,7 @@ FlashAttention-2 with CUDA currently supports:
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

### AMD ROCm Support
ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
The ROCm version has two backends: [composable_kernel](https://github.com/ROCm/composable_kernel) (CK), which is the default, and a [Triton](https://github.com/triton-lang/triton) backend. Both provide an implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.
@@ -121,11 +121,72 @@ We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

FlashAttention-2 with ROCm currently supports:
#### Composable Kernel Backend
The FlashAttention-2 ROCm CK backend currently supports:
1. MI200 or MI300 GPUs.
2. Datatypes fp16 and bf16.
3. Forward head dimensions up to 256; backward head dimensions up to 128.

#### Triton Backend
The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPUs with the fp16, bf16, and fp32 datatypes.

These features are supported in both the forward and backward passes (a minimal usage sketch follows after these lists):
1) Causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes
5) Multi and grouped query attention
6) Dropout
7) Rotary embeddings

These features are supported in the forward pass only for now; backward support is planned:
1) ALiBi and matrix bias

These features are in development:
1) Paged attention
2) Sliding window
3) Performance improvements
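
As a quick reference, here is a minimal sketch of calling `flash_attn_func` with causal masking and grouped-query attention through the Triton backend. The shapes, dtypes, and device placement below are illustrative assumptions, not requirements.

```
# Minimal sketch (assumed shapes/dtypes): causal GQA attention via the Triton backend.
# FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" must be exported before flash_attn is imported.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads_q, nheads_kv, headdim = 2, 1024, 8, 2, 64  # 8 query heads share 2 KV heads
q = torch.randn(batch, seqlen, nheads_q, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # -> (batch, seqlen, nheads_q, headdim)
```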

##### Getting Started
To get started with the Triton backend for AMD, follow the steps below.

First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4).

```
git clone https://github.com/triton-lang/triton
cd triton
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
pip install --verbose -e python
```
Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

```
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn_triton_amd.py
```

###### Docker
We have also created a Dockerfile.

To build the Docker image:
```
cd flash_attn/flash_attn_triton_amd
docker build -t fa_triton .
```

To run the Docker image:
```
docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton
```
Inside the container, you should land in the flash-attention repo with everything installed. Run the following command to test things:
```
pytest tests/test_flash_attn_triton_amd.py
```


## How to use FlashAttention

71 changes: 62 additions & 9 deletions flash_attn/flash_attn_interface.py
@@ -4,10 +4,15 @@

import torch
import torch.nn as nn
import os

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
else:
import flash_attn_2_cuda as flash_attn_gpu

# isort: on

@@ -85,10 +90,14 @@ def _flash_attn_forward(
window_size_right: int,
softcap: float,
alibi_slopes: Optional[torch.Tensor],
return_softmax: bool
return_softmax: bool,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_p: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
q,
k,
v,
@@ -102,6 +111,10 @@
softcap,
return_softmax,
None,
descale_q,
descale_k,
descale_v,
descale_p
)
return out, softmax_lse, S_dmask, rng_state

@@ -159,9 +172,13 @@ def _flash_attn_varlen_forward(
block_table: Optional[torch.Tensor] = None,
leftpad_k: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_p: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
q,
k,
v,
@@ -183,6 +200,10 @@
softcap,
return_softmax,
None,
descale_q,
descale_k,
descale_v,
descale_p
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
@@ -260,7 +281,7 @@ def _flash_attn_backward(
dk,
dv,
softmax_d,
) = flash_attn_cuda.bwd(
) = flash_attn_gpu.bwd(
dout,
q,
k,
@@ -356,7 +377,7 @@ def _flash_attn_varlen_backward(
dk,
dv,
softmax_d,
) = flash_attn_cuda.varlen_bwd(
) = flash_attn_gpu.varlen_bwd(
dout,
q,
k,
@@ -799,6 +820,10 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
descale_q,
descale_k,
descale_v,
descale_p
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -819,6 +844,10 @@
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
@@ -862,7 +891,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -885,6 +914,10 @@ def forward(
deterministic,
return_softmax,
block_table,
descale_q,
descale_k,
descale_v,
descale_p
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -910,6 +943,10 @@
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
descale_p=descale_p
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -961,7 +998,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
Expand Down Expand Up @@ -1111,6 +1148,10 @@ def flash_attn_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1172,6 +1213,10 @@ def flash_attn_func(
alibi_slopes,
deterministic,
return_attn_probs,
descale_q,
descale_k,
descale_v,
descale_p
)


@@ -1348,6 +1393,10 @@ def flash_attn_varlen_func(
deterministic=False,
return_attn_probs=False,
block_table=None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1421,6 +1470,10 @@ def flash_attn_varlen_func(
deterministic,
return_attn_probs,
block_table,
descale_q,
descale_k,
descale_v,
descale_p
)


@@ -1544,7 +1597,7 @@ def flash_attn_with_kvcache(
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
out, softmax_lse = flash_attn_gpu.fwd_kvcache(
q,
k_cache,
v_cache,
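
Note: the new `descale_q`, `descale_k`, `descale_v`, and `descale_p` arguments threaded through `flash_attn_func` and `flash_attn_varlen_func` are exercised by the fp8 tests in the workflow above. Their exact semantics are not documented in this diff; the sketch below is a hypothetical illustration that assumes they are per-tensor float32 dequantization scales for fp8 inputs.

```
# Hypothetical sketch of the fp8 path with the new descale arguments.
# Assumptions (not confirmed by this PR): inputs are fp8 (float8_e4m3fnuz on ROCm)
# and each descale_* is a scalar float32 tensor that undoes the quantization scale.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 512, 4, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda").to(torch.float8_e4m3fnuz)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda").to(torch.float8_e4m3fnuz)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda").to(torch.float8_e4m3fnuz)

descale = torch.tensor(1.0, dtype=torch.float32, device="cuda")  # identity scale, for illustration
out = flash_attn_func(
    q, k, v, causal=True,
    descale_q=descale, descale_k=descale, descale_v=descale, descale_p=descale,
)
```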