diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml new file mode 100644 index 000000000..c41ac0d0b --- /dev/null +++ b/.github/workflows/amd_tests.yml @@ -0,0 +1,90 @@ +name: AMD Perf Kernel Tests + +on: + workflow_dispatch: + pull_request: + branches: [main_perf] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + Integration-Tests-AMD: + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [linux-mi300-gpu-1, gfx1100] + fail-fast: false # disables failing the entire job when one matrix entry fails + container: + image: rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Show Device Info + run: | + rocminfo | grep gfx + + - name: Uninstall Triton + run : | + pip uninstall -y triton + rm -rf ~/.triton + rm -rf ./triton/python/build + + - name: Install Triton + run: | + git clone https://github.com/triton-lang/triton + cd triton + git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 + pip install ninja cmake wheel pybind11 # build-time dependencies + pip install matplotlib pandas pytest # triton bench dependencies + pip install --verbose --no-build-isolation ./python + cd .. + + - name: Show Triton version + run: | + pip show triton + + - name: Build + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + python setup.py install + + - name: Flash Attention Tests using Pytorch reference implementation + if: matrix.runner == 'linux-mi300-gpu-1' + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + FLASH_ATTENTION_TRITON_AMD_REF=1 pytest tests/test_flash_attn_triton_amd.py + + # CDNA Tests + - name: Flash Attention CDNA Tests + if: matrix.runner == 'linux-mi300-gpu-1' + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + pytest tests/test_flash_attn_triton_amd.py + + # FIXME: run the full suite + - name: AMD Tests + if: matrix.runner == 'linux-mi300-gpu-1' + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + pytest -v -s flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_fp8 flash_attn/flash_attn_triton_amd/test.py::test_op_prefill_varlen_fp8 + + - name: AMD Bench + if: matrix.runner == 'linux-mi300-gpu-1' + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 python flash_attn/flash_attn_triton_amd/bench.py + + # RDNA Tests + - name: Flash Attention RDNA Tests + if: matrix.runner == 'gfx1100' + run: | + export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + + # NOTE: this exceeds 6 hrs on "gfx1100" so we are testing a subset of the tests. The full suite is run on a CDNA machine. + pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output diff --git a/.gitignore b/.gitignore index c0a6c7cb1..ddc0f514c 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,17 @@ var/ .idea/ # Dev -venv \ No newline at end of file +venv + +# AMD +scripts +csrc/flash_attn_ck +.eggs +*.log +core.* +gpucore.* +*.csv +*.png +*.html +*.json +*.txt diff --git a/README.md b/README.md index 054af18c4..db6d2ff82 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ FlashAttention-2 with CUDA currently supports: 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5. 
### AMD ROCm Support -ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2. +ROCm version has two backends. There is [composable_kernel](https://github.com/ROCm/composable_kernel) (ck) which is the default backend and a [Triton](https://github.com/triton-lang/triton) backend. They provide an implementation of FlashAttention-2. **Requirements:** - ROCm 6.0 and above. @@ -121,11 +121,72 @@ We recommend the [Pytorch](https://hub.docker.com/r/rocm/pytorch) container from ROCm, which has all the required tools to install FlashAttention. -FlashAttention-2 with ROCm currently supports: +#### Composable Kernel Backend +FlashAttention-2 ROCm CK backend currently supports: 1. MI200 or MI300 GPUs. 2. Datatype fp16 and bf16 3. Forward's head dimensions up to 256. Backward head dimensions up to 128. +#### Triton Backend +The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. + +It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. + +These features are supported in Fwd and Bwd +1) Fwd and Bwd with causal masking +2) Variable sequence lengths +3) Arbitrary Q and KV sequence lengths +4) Arbitrary head sizes +5) Multi and grouped query attention +6) Dropout +7) Rotary embeddings + +These features are supported in Fwd for now. We will add them to backward soon. +2) ALiBi and matrix bias + +These features are in development +1) Paged Attention +2) Sliding Window +5) Performance Improvements + +##### Getting Started +To get started with the triton backend for AMD, follow the steps below. + +First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4). + +``` +git clone https://github.com/triton-lang/triton +cd triton +git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 +pip install --verbose -e python +``` +Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. + +``` +export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" +cd flash-attention +python setup.py install +pytest tests/test_flash_attn_triton_amd.py +``` + +###### Docker +We have also created a Dockerfile. + +To build the docker file +``` +cd flash_attn/flash_attn_triton_amd +docker build -t fa_triton . +``` + +To run the docker image +``` +docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +``` +Inside the docker, it should open to the flash attention repo with everything installed. You can run the following command to test things. 
+``` +pytest tests/test_flash_attn_triton_amd.py +``` + ## How to use FlashAttention diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 91fe2a918..de6bd9aa5 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -4,10 +4,15 @@ import torch import torch.nn as nn +import os # isort: off # We need to import the CUDA kernels after importing torch -import flash_attn_2_cuda as flash_attn_cuda +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + from .flash_attn_triton_amd import interface_fa as flash_attn_gpu +else: + import flash_attn_2_cuda as flash_attn_gpu # isort: on @@ -85,10 +90,14 @@ def _flash_attn_forward( window_size_right: int, softcap: float, alibi_slopes: Optional[torch.Tensor], - return_softmax: bool + return_softmax: bool, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_p: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( + out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( q, k, v, @@ -102,6 +111,10 @@ def _flash_attn_forward( softcap, return_softmax, None, + descale_q, + descale_k, + descale_v, + descale_p ) return out, softmax_lse, S_dmask, rng_state @@ -159,9 +172,13 @@ def _flash_attn_varlen_forward( block_table: Optional[torch.Tensor] = None, leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_p: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( + out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( q, k, v, @@ -183,6 +200,10 @@ def _flash_attn_varlen_forward( softcap, return_softmax, None, + descale_q, + descale_k, + descale_v, + descale_p ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -260,7 +281,7 @@ def _flash_attn_backward( dk, dv, softmax_d, - ) = flash_attn_cuda.bwd( + ) = flash_attn_gpu.bwd( dout, q, k, @@ -356,7 +377,7 @@ def _flash_attn_varlen_backward( dk, dv, softmax_d, - ) = flash_attn_cuda.varlen_bwd( + ) = flash_attn_gpu.varlen_bwd( dout, q, k, @@ -799,6 +820,10 @@ def forward( alibi_slopes, deterministic, return_softmax, + descale_q, + descale_k, + descale_v, + descale_p ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -819,6 +844,10 @@ def forward( softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_p=descale_p, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p @@ -862,7 +891,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -885,6 +914,10 @@ def forward( deterministic, 
return_softmax, block_table, + descale_q, + descale_k, + descale_v, + descale_p ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -910,6 +943,10 @@ def forward( alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=block_table, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_p=descale_p ) ctx.save_for_backward( q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state @@ -961,7 +998,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -1111,6 +1148,10 @@ def flash_attn_func( alibi_slopes=None, deterministic=False, return_attn_probs=False, + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -1172,6 +1213,10 @@ def flash_attn_func( alibi_slopes, deterministic, return_attn_probs, + descale_q, + descale_k, + descale_v, + descale_p ) @@ -1348,6 +1393,10 @@ def flash_attn_varlen_func( deterministic=False, return_attn_probs=False, block_table=None, + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -1421,6 +1470,10 @@ def flash_attn_varlen_func( deterministic, return_attn_probs, block_table, + descale_q, + descale_k, + descale_v, + descale_p ) @@ -1544,7 +1597,7 @@ def flash_attn_with_kvcache( cache_seqlens = maybe_contiguous(cache_seqlens) cache_batch_idx = maybe_contiguous(cache_batch_idx) block_table = maybe_contiguous(block_table) - out, softmax_lse = flash_attn_cuda.fwd_kvcache( + out, softmax_lse = flash_attn_gpu.fwd_kvcache( q, k_cache, v_cache, diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile new file mode 100644 index 000000000..91e4c4734 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/Dockerfile @@ -0,0 +1,23 @@ +FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 + +WORKDIR /workspace + +# install triton +RUN git clone https://github.com/triton-lang/triton &&\ + cd triton &&\ + git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 &&\ + pip uninstall -y triton &&\ + cd python &&\ + pip install matplotlib pandas pytest einops &&\ + pip install --verbose -e . + +# install flash attention +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + +RUN git clone https://github.com/ROCm/flash-attention.git &&\ + cd flash-attention &&\ + git checkout main_perf &&\ + python setup.py install + +# set working dir +WORKDIR /workspace/flash-attention \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md new file mode 100644 index 000000000..560ae6bac --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/README.md @@ -0,0 +1,67 @@ +Flash Attention Triton Kernel +=============== + +#### Introduction +The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. 
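The backend is selected entirely through the `FLASH_ATTENTION_TRITON_AMD_ENABLE` environment variable, which is read when `flash_attn` is imported. A minimal usage sketch (illustrative only; it assumes a ROCm build of PyTorch and that the install steps from the Getting Started section below have been completed):

```
import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"  # must be set before importing flash_attn

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) layout; fewer KV heads exercises MQA/GQA
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 4, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)
```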
+ +It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. + +These features are supported in Fwd and Bwd +1) Fwd and Bwd with causal masking +2) Variable sequence lengths +3) Arbitrary Q and KV sequence lengths +4) Arbitrary head sizes +5) Multi and grouped query attention +6) Dropout +7) Rotary embeddings + +These features are supported in Fwd for now. We will add them to backward soon. +2) ALiBi and matrix bias + +These features are in development +1) Paged Attention +2) Sliding Window +5) Performance Improvements + +##### Getting Started +To get started with the triton backend for AMD, follow the steps below. + +First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4). + +``` +git clone https://github.com/triton-lang/triton +cd triton +git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 +pip install --verbose -e python +``` +Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. + +``` +export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" +cd flash-attention +python setup.py install +pytest tests/test_flash_attn_triton_amd.py +``` + +###### Docker +We have also created a Dockerfile. + +To build the docker file +``` +cd flash_attn/flash_attn_triton_amd +docker build -t fa_triton . +``` + +To run the docker image +``` +docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +``` +Inside the docker, it should open to the flash attention repo with everything installed. You can run the following command to test things. +``` +pytest tests/test_flash_attn_triton_amd.py +``` + +##### Credits +AMD Triton kernels team + +OpenAI kernel team diff --git a/flash_attn/flash_attn_triton_amd/__init__.py b/flash_attn/flash_attn_triton_amd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py new file mode 100644 index 000000000..0d1361274 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bench.py @@ -0,0 +1,292 @@ +import argparse +import torch +import triton +from flash_attn.flash_attn_triton_amd.utils import ( + MetaData, + input_helper, + varlen_input_helper, +) +from flash_attn.flash_attn_triton_amd.interface_torch import attention_prefill, attention_decode + +ARGS_TO_TORCH_DTYPE = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, +} + +FUNCTIONS = { + "prefill": attention_prefill, + "decode": attention_decode +} + +def get_benchmark_configs(args, varlen=False): + """ + Returns benchmark configurations based on whether variable-length sequences are used. 
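    Each configuration is a tuple of (BATCH, HQ, HK, N_CTX_Q, N_CTX_K).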
+ """ + if args.custom_config: + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk + return [(args.b, args.hq, hk, args.sq, sk)] + elif varlen: + return [ + (2, 16, 4, 1024, 1024), + (8, 16, 2, 2048, 2048), + (4, 16, 8, 4096, 4096), + (2, 16, 4, 8192, 8192), + (2, 16, 8, 16384, 16384), + (2, 48, 12, 1024, 1024), + (2, 48, 24, 2048, 2048), + (2, 48, 8, 4096, 4096), + (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + (2, 64, 32, 1024, 1024), + (4, 64, 16, 2048, 2048), + (4, 64, 8, 4096, 4096), + (4, 64, 32, 8192, 8192), + (4, 128, 16, 16384, 16384), + ] + else: + return [ + (16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (1, 8, 8, 8192, 8192), + (1, 2, 2, 16384, 16384), + (2, 48, 48, 1024, 1024), + (2, 48, 48, 2048, 1024), + (1, 8, 8, 4096, 8192), + (1, 8, 8, 8192, 4096), + (2, 4, 4, 16384, 8192), + (2, 8, 8, 1989, 15344), + (4, 16, 16, 4097, 163), + (2, 16, 16, 8122, 2159), + (1, 16, 16, 16281, 7), + (2, 48, 48, 1021, 1020), + (2, 48, 48, 2001, 2048), + (2, 8, 8, 3996, 9639), + (2, 8, 8, 8181, 1021), + ] + +def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal): + flops_per_matmul = 0 + + if fn_name.startswith("prefill"): + if layout == "thd": + q, k, v, input_metadata = varlen_input_helper( + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device) + for i in range(input_metadata.num_contexts): + seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] + seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] + flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 + else: + q, k, v, input_metadata = input_helper( + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device + ) + flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + + if causal: + input_metadata.need_causal() + + o = torch.empty_like(q) + input_data = (q, k, v, o, input_metadata) + elif fn_name.startswith("decode"): + q = torch.randn( + [BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD], + device=device, + dtype=dtype, + requires_grad=False, + ) + k = torch.randn( + [BATCH, N_CTX_K, HK, 1, D_HEAD], + device=device, + dtype=dtype, + requires_grad=False, + ).expand(-1, -1, -1, HQ // HK, -1) + v = torch.randn( + [BATCH, N_CTX_K, HK, 1, D_HEAD], + device=device, + dtype=dtype, + requires_grad=False, + ).expand(-1, -1, -1, HQ // HK, -1) + input_metadata = MetaData(sm_scale=1.3) + input_metadata.layout = "bsghd" + + # Adjust flops calculation if needed + flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + + input_data = (q, k, v, input_metadata) + else: + raise ValueError("Unsupported benchmark function") + + input_metadata.use_exp2 = True + return input_data, flops_per_matmul + +def run_benchmark(args, fn_name, fn, mode): + """ + Runs the benchmark for the provided function based on the provided arguments. 
+ """ + print(f"Benchmarking {fn_name} in {mode} mode...") + + dtype = ARGS_TO_TORCH_DTYPE[args.dtype] + head_size = args.d if args.d else 128 + causal = args.causal + varlen = args.layout == "thd" + return_tflops = args.return_tflops + line_names = "TFLOPS" if return_tflops else "Time (ms)" + + # Determine configurations + x_vals_list = get_benchmark_configs(args, varlen=varlen) + + # Setup benchmark configurations + configs = [ + triton.testing.Benchmark( + x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"], + x_vals=x_vals_list, + line_arg="provider", + line_vals=["triton"], + line_names=[line_names], + styles=[("red", "-")], + ylabel="ms", + plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}", + args={ + "D_HEAD": head_size, + "dtype": dtype, + "causal": causal, + "mode": mode, + }, + ) + ] + + @triton.testing.perf_report(configs) + def bench_function( + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda" + ): + warmup = 25 + rep = 100 + flops_per_matmul = 0 + + # generate function inputs + fn_inputs, flops_per_matmul = gen_fn_inputs( + fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal + ) + + # define the function to benchmark + if mode == "fwd": + benchmark_fn = lambda: fn(*fn_inputs) + total_flops = 2 * flops_per_matmul + elif mode == "bwd": + outputs = fn(*fn_inputs) + output = outputs[0] + grad_output = torch.randn_like(output) + benchmark_fn = lambda: output.backward(grad_output, retain_graph=True) + total_flops = 2 * flops_per_matmul * 2.5 + else: + raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.") + + if causal: + total_flops *= 0.5 + + # Run the benchmark + ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep) + + if return_tflops: + return total_flops / ms * 1e-9 + else: + return ms + + bench_function.run(save_path=".", print_data=True) + +def supported_layouts(): + """ + Returns a string describing the supported layouts. + """ + return ( + "bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n" + "bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n" + "thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n" + 'This layout is sometimes called "varlen" or "grouped" layout.' + ) + +def parse_args(): + """ + Parses command-line arguments. 
+ """ + parser = argparse.ArgumentParser( + prog="Benchmark FlashAttention", + allow_abbrev=False, + ) + parser.add_argument("-b", type=int, default=0) + parser.add_argument("-hq", type=int, default=0) + parser.add_argument("-hk", type=int, default=0) + parser.add_argument("-sq", type=int, default=0) + parser.add_argument("-sk", type=int, default=0) + parser.add_argument( + "-equal_seqlens", + action="store_true", + default=False, + help="If specified, each context within the thd layout has same seqlen as sq and sk", + ) + parser.add_argument("-d", type=int, default=0) + parser.add_argument("-causal", action="store_true", default=False) + parser.add_argument("-dtype", default="fp16") + parser.add_argument("-return_tflops", action="store_true", default=False) + parser.add_argument( + "-layout", + type=str, + default="bhsd", + help=supported_layouts(), + ) + parser.add_argument( + "-benchmark_fn", + type=str, + nargs="*", + choices=FUNCTIONS.keys(), + help="Function(s) to benchmark: prefill, decode, or both", + ) + parser.add_argument( + "-mode", + type=str, + nargs='*', + default=["fwd", "bwd"], + choices=["fwd", "bwd"], + help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass", + ) + return parser.parse_args() + +def main(): + """ + Main function to run benchmarks. + """ + args = parse_args() + + # Validate arguments + assert ( + args.layout == "thd" or not args.equal_seqlens + ), "Equal sequence lengths arg must be used with the thd layout." + args.custom_config = False + if args.b or args.hq or args.hk or args.sq or args.sk or args.d: + args.custom_config = True + assert args.b and args.hq and args.sq and args.d, ( + "If custom config is specified, please provide all of batch, " + "number of Q heads, Q sequence length, and head size." + ) + assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported." 
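    # Example invocations (illustrative):
    #   python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn prefill -mode fwd -causal -return_tflops
    #   python flash_attn/flash_attn_triton_amd/bench.py -b 4 -hq 16 -hk 4 -sq 4096 -d 128 -dtype bf16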
+ + # determine the functions to benchmark + if args.benchmark_fn is None or len(args.benchmark_fn) == 0: + bench_fn_list = FUNCTIONS.keys() + else: + bench_fn_list = args.benchmark_fn + + # benchmark functions + for fn_name in bench_fn_list: + if fn_name not in FUNCTIONS: + raise ValueError(f"Invalid benchmark function specified: {fn_name}") + for mode in args.mode: + if fn_name == "decode" and mode == "bwd": + print(f"Decode kernel doesnot have a backward pass") + continue + run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode) + +if __name__ == "__main__": + main() diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py new file mode 100644 index 000000000..a30fe3fe5 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py @@ -0,0 +1,762 @@ +import torch +import triton +import triton.language as tl +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shape_from_layout, get_strides_from_layout, write_dropout_mask, create_dropout_mask + +# TODO: move this into utils.py so it's shared among kernels +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH +tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP + +@triton.jit +def _bwd_preprocess_use_o( + Out, + DO, + Delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_doz, stride_doh, stride_dom, stride_dok, + stride_deltaz, stride_deltah, stride_deltam, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + N_CTX_Q: tl.constexpr, + Z: tl.constexpr, + H: tl.constexpr, + IS_VARLEN: tl.constexpr +): + pid_bh = tl.program_id(0) + pid_m = tl.program_id(1) + + # Compute batch and head indices + off_z = pid_bh // H + off_h = pid_bh % H + + if IS_VARLEN: + # Compute sequence lengths for the current batch + q_start = tl.load(cu_seqlens_q + off_z) + q_end = tl.load(cu_seqlens_q + off_z + 1) + k_start = tl.load(cu_seqlens_k + off_z) + k_end = tl.load(cu_seqlens_k + off_z + 1) + + # Compute actual sequence lengths + N_CTX_Q = q_end - q_start + N_CTX_K = k_end - k_start + else: + q_start = 0 + k_start = 0 + N_CTX_Q = max_seqlen_q + N_CTX_K = max_seqlen_k + + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_d = tl.arange(0, BLOCK_DMODEL) + + # create masks + mask_m = off_m < N_CTX_Q + mask_d = off_d < ACTUAL_BLOCK_DMODEL + + # compute offsets + o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om + do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om + + # compute pointers + out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok + do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok + + # load + o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + + # compute delta + delta = tl.sum(o * do, axis=1) + + # write-back delta + delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam + delta_ptrs = delta_offset + off_m * stride_deltam + tl.store(delta_ptrs, delta, mask=mask_m) + + +@triton.jit +def _bwd_kernel_one_col_block( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + D, + q_offset, + k_offset, + v_offset, + do_offset, + dq_offset, + dk_offset, + dv_offset, + l_offset, + delta_offset, + dropout_offset, + 
stride_dq_all, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_deltaz, + stride_deltah, + stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, + N_CTX_Q, + N_CTX_K, + start_n, + num_block_m, + num_block_n, + dropout_p, + philox_seed, + batch_philox_offset, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + CAUSAL: tl.constexpr, + DROPOUT: tl.constexpr, + USE_EXP2: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + if CAUSAL: + # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M + lo = 0 + else: + lo = 0 + + # initialize col and head offsets + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # masks + mask_n = offs_n < N_CTX_K + mask_d = offs_d < ACTUAL_BLOCK_DMODEL + kv_mask = mask_n[:, None] & mask_d[None, :] + + + # initialize grad accumulators + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + + # load k and v once per column block + k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk + k = tl.load(k_ptrs, mask=kv_mask, other=0.0) + v = tl.load(v_ptrs, mask=kv_mask, other=0.0) + + # loop over rows + for start_m in range(lo, num_block_m): + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + + # update mask as row block changes + mask_m = offs_m < N_CTX_Q + q_mask = mask_m[:, None] & mask_d[None, :] + + # load q, k, v, do on-chip + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + do = tl.load(do_ptrs, mask=q_mask, other=0.0) + + # recompute p = softmax(qk, dim=-1).T + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + if CAUSAL: + col_offset = N_CTX_Q - N_CTX_K + causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :]) + qk = tl.where(causal_mask, qk, float("-inf")) + + l_ptrs = l_offset + offs_m * stride_deltam + l_i = tl.load(l_ptrs, mask=mask_m) + + # compute p + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + qk *= sm_scale * RCP_LN2 + l_i *= RCP_LN2 + p = tl.math.exp2(qk - l_i[:, None]) + else: + qk *= sm_scale + p = tl.math.exp(qk - l_i[:, None]) + + # mask block in the cases where the data is smaller the block size + p_mask = mask_m[:, None] & mask_n[None, :] + p = tl.where(p_mask, p, 0.0) + + if DROPOUT: + # NOTE: must create a new var p_drop to prevent p (which is used later to compute ds) from changing + philox_offset = batch_philox_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + # print("philox_seed:", philox_seed) + # print("philox_offset:", philox_offset) + if tl_DROPOUT_USE_PYTORCH: + dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load(dropout_ptrs, mask=p_mask) + else: + rand_vals = tl.rand(philox_seed, philox_offset) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1/ (1 - dropout_p) + + if tl_DROPOUT_DUMP: + dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + 
offs_n[None, :] * stride_dropoutn + tl.store(dropout_ptrs, dropout_mask, mask=p_mask) + + # apply dropout mask + p_drop = tl.where(dropout_mask, p, 0.0) + p_drop_scaled = p_drop * dropout_scale + p_drop_scaled = p_drop_scaled.to(tl.float16) + + # compute dv + dv += tl.dot(tl.trans(p_drop_scaled), do) + + # compute dp + dp_drop_scaled = tl.dot(do, tl.trans(v)) + dp = tl.where(dropout_mask, dp_drop_scaled, 0.0) * dropout_scale + + # compute ds + delta_ptrs = delta_offset + offs_m * stride_deltam + delta_i = tl.load(delta_ptrs, mask=mask_m) + dscores_scaled = (p * (dp - delta_i[:, None])) + ds = dscores_scaled * sm_scale + ds = tl.where(p_mask, ds, 0.0) + ds = ds.to(tl.float16) + else: + p = p.to(tl.float16) + + # compute dv + dv += tl.dot(tl.trans(p), do) + + # compute dp + dp = tl.dot(do, tl.trans(v)) + + # compute ds + delta_ptrs = delta_offset + offs_m * stride_deltam + delta_i = tl.load(delta_ptrs, mask=mask_m) + dscores_scaled = (p * (dp - delta_i[:, None])) + ds = dscores_scaled * sm_scale + ds = tl.where(p_mask, ds, 0.0) + ds = ds.to(tl.float16) + + # compute dk + dk += tl.dot(tl.trans(ds), q) + + # compute dq + if SEQUENCE_PARALLEL: + dq = tl.dot(ds, k) + else: + dq = tl.load(dq_ptrs, mask=q_mask, other=0.0) + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask) + + # write-back dv and dk + dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk + + # write-back + if GROUP_SIZE != 1: + # use atomic_add to properly accumulate gradients from multiple query heads + tl.atomic_add(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) + tl.atomic_add(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) + else: + tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) + tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) + +@triton.jit +def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + Delta, + Dropout_mask, + stride_dq_all, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_deltaz, + stride_deltah, + stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, + Z, + HQ, + HK, + num_block_m, + num_block_n, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset_base, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + CAUSAL: tl.constexpr, + DROPOUT: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + # program ids + off_zh = tl.program_id(0) + if SEQUENCE_PARALLEL: + start_n = tl.program_id(1) + off_z = off_zh // HQ + off_hq = off_zh % HQ + + GROUP_SIZE = HQ // HK + if GROUP_SIZE != 1: + off_hk = off_hq // GROUP_SIZE + else: + off_hk = off_hq + + if IS_VARLEN: + # Compute sequence lengths for the current batch + q_start = tl.load(cu_seqlens_q + off_z) + q_end = tl.load(cu_seqlens_q + off_z + 1) + k_start = tl.load(cu_seqlens_k + off_z) + k_end = tl.load(cu_seqlens_k + off_z + 1) + + # Compute actual sequence lengths + N_CTX_Q = q_end - q_start + N_CTX_K = k_end - k_start + else: + q_start = 0 + k_start = 0 + N_CTX_Q = max_seqlen_q + N_CTX_K = max_seqlen_k + + # input tensor offsets + q_offset = Q + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + k_offset = K + off_z * stride_kz + off_hk * stride_kh + 
k_start * stride_kn + v_offset = V + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn + do_offset = DO + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam + delta_offset = Delta + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam + + if DROPOUT: + batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm + dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm + else: + batch_philox_offset = 0 + dropout_offset = 0 + + + # output tensor offsets + dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn + dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn + if SEQUENCE_PARALLEL: + dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + else: + dq_offset = DQ + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + + # inner loop + if SEQUENCE_PARALLEL: + _bwd_kernel_one_col_block( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + Delta, + q_offset, + k_offset, + v_offset, + do_offset, + dq_offset, + dk_offset, + dv_offset, + l_offset, + delta_offset, + dropout_offset, + stride_dq_all, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_deltaz, + stride_deltah, + stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, + N_CTX_Q, + N_CTX_K, + start_n, + num_block_m, + num_block_n, + dropout_p, philox_seed, batch_philox_offset, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, + CAUSAL=CAUSAL, + DROPOUT=DROPOUT, + USE_EXP2=USE_EXP2, + GROUP_SIZE=GROUP_SIZE + ) + else: + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + Delta, + q_offset, + k_offset, + v_offset, + do_offset, + dq_offset, + dk_offset, + dv_offset, + l_offset, + delta_offset, + dropout_offset, + stride_dq_all, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_deltaz, + stride_deltah, + stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, + N_CTX_Q, + N_CTX_K, + start_n, + num_block_m, + num_block_n, + dropout_p, philox_seed, batch_philox_offset, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, + CAUSAL=CAUSAL, + DROPOUT=DROPOUT, + USE_EXP2=USE_EXP2, + GROUP_SIZE=GROUP_SIZE + ) + + +# NOTE: smaller blocks have lower accuracy. more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom. 
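# For reference, the per-block math implemented by _bwd_kernel_one_col_block above,
# written as a plain PyTorch sketch. This is illustrative only (it assumes the "bhsd"
# layout, no varlen, no dropout, and fp32 accumulation) and is not called by the kernels.
def _bwd_reference_torch(q, k, v, do, softmax_lse, delta, sm_scale, causal=False):
    # q, k, v, do: (batch, nheads, seqlen, head_dim); softmax_lse, delta: (batch, nheads, seqlen_q)
    qk = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float())
    if causal:
        # bottom-right aligned causal mask, matching col_offset = N_CTX_Q - N_CTX_K above
        m = torch.arange(q.shape[2], device=q.device)[:, None]
        n = torch.arange(k.shape[2], device=k.device)[None, :]
        qk = qk.masked_fill(m + (k.shape[2] - q.shape[2]) < n, float("-inf"))
    # softmax_lse is the log-sum-exp of the scaled scores saved by the forward pass
    p = torch.exp(qk * sm_scale - softmax_lse.unsqueeze(-1).float())
    dv = torch.einsum("bhmn,bhmd->bhnd", p, do.float())            # dv = p^T @ do
    dp = torch.einsum("bhmd,bhnd->bhmn", do.float(), v.float())    # dp = do @ v^T
    ds = p * (dp - delta.unsqueeze(-1).float()) * sm_scale         # delta = rowsum(o * do)
    dk = torch.einsum("bhmn,bhmd->bhnd", ds, q.float())            # dk = ds^T @ q
    dq = torch.einsum("bhmn,bhnd->bhmd", ds, k.float())            # dq = ds @ k
    return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype)
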
+def attention_prefill_backward_triton_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale: float, + alibi_slopes, + causal, + layout: str, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p, + philox_seed, + philox_offset, + use_exp2: bool, + sequence_parallel = True, +): + if DEBUG: + print() + print("attention_prefill_backward_triton_impl") + print("do:", do, do.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("o:", o, o.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) + print("sm_scale:", sm_scale) + print("alibi_slopes:", alibi_slopes) + print("causal:", causal) + print("layout:", layout) + print("cu_seqlens_q:", cu_seqlens_q) + print("cu_seqlens_k:", cu_seqlens_k) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) + print("use_exp2:", use_exp2) + print("sequence_parallel:", sequence_parallel) + + # make contigious + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + softmax_lse = softmax_lse.contiguous() + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) + stride_qz, stride_qh, stride_qm, stride_qk = q_strides + stride_kz, stride_kh, stride_kn, stride_kk = k_strides + stride_vz, stride_vh, stride_vn, stride_vk = v_strides + stride_oz, stride_oh, stride_om, stride_ok = o_strides + is_varlen = layout == "thd" + use_dropout = (dropout_p > 0.0) + + # FIXME: some configs lead to oom for some reason when using 64 x 64 blocks + if max_seqlen_q <= 32 or max_seqlen_k <= 32: + BLOCK_M = 32 + BLOCK_N = 32 + else: + BLOCK_M = 64 + BLOCK_N = 64 + if DEBUG: + print("BLOCK_M:", BLOCK_M) + print("BLOCK_N:", BLOCK_N) + + num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful + num_stages = 1 + waves_per_eu = 1 + + # divide up the problem + num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M) + num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N) + + # get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + BLOCK_DMODEL = padded_d_model + ACTUAL_BLOCK_DMODEL = head_size + + do = do.contiguous() + + # deal with dq + if dq is None: + if sequence_parallel: + dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype) + stride_dq_all = dq.stride()[0] + + # deal with dk, dv + if (dk is None) or (dv is None): + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + + # zero out + dq.zero_() + dk.zero_() + dv.zero_() + + # assert contigious + assert do.is_contiguous() + assert q.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert o.is_contiguous() + assert softmax_lse.is_contiguous() + + # init delta + delta = torch.empty_like(softmax_lse) + if is_varlen: + stride_deltam, stride_deltah = delta.stride() + stride_deltaz = 0 + else: + stride_deltaz, stride_deltah, stride_deltam = delta.stride() + + # dropout mask tensor for debugging. 
We dump the dropout mask created in the kernel for testing + if use_dropout: + if DROPOUT_USE_PYTORCH: + dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlen_q, max_seqlen_k), seed = philox_seed) + else: + dropout_mask = torch.zeros((batch, nheads_q, max_seqlen_q, max_seqlen_k), device=q.device, + dtype=torch.float32) + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (dropout_mask.stride(0), dropout_mask.stride(1), dropout_mask.stride(2), dropout_mask.stride(3)) + else: + dropout_mask = None + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0 , 0 , 0) + + + _bwd_preprocess_use_o[(batch * nheads_q, num_blocks_m)]( + o, + do, + delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_oz, stride_oh, stride_om, stride_ok, + stride_deltaz, stride_deltah, stride_deltam, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + N_CTX_Q=max_seqlen_q, + Z=batch, + H=nheads_q, + IS_VARLEN=is_varlen + ) + + if False: + print("_bwd_kernel inputs") + print("do:", do, do.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("sm_scale", sm_scale) + print("o:", o, o.shape) + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("L:", softmax_lse, softmax_lse.shape) + print("delta:", delta, delta.shape) + print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk) + print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk) + print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk) + print("batch_q:", batch) + print("heads_q:",nheads_q) + print("max_seqlen_q:",max_seqlen_q) + print("max_seqlen_k:",max_seqlen_k) + print("dropout_p:",dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:",philox_offset) + print("BLOCK_M:",BLOCK_M) + print("BLOCK_N:",BLOCK_M) + print("BLOCK_DMODEL:",BLOCK_DMODEL) + print("ACTUAL_BLOCK_DMODEL:",ACTUAL_BLOCK_DMODEL) + print("SEQUENCE_PARALLEL:",sequence_parallel) + print("CAUSAL:",causal) + print("DROPOUT:", use_dropout) + print("num_warps:",num_warps) + print("num_stages:", num_stages) + print("USE_EXP2:", use_exp2) + print("num_blocks_m:", num_blocks_m) + print("num_blocks_n:", num_blocks_n) + + _bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)]( + q, + k, + v, + sm_scale, + o, + do, + dq, + dk, + dv, + softmax_lse, + delta, + dropout_mask, + stride_dq_all, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_deltaz, stride_deltah, stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, + batch, + nheads_q, + nheads_k, + num_blocks_m, + num_blocks_n, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, philox_seed, philox_offset, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + SEQUENCE_PARALLEL=sequence_parallel, + CAUSAL=causal, + DROPOUT=use_dropout, + USE_EXP2=use_exp2, + num_warps=num_warps, + num_stages=num_stages, + waves_per_eu = waves_per_eu, + IS_VARLEN=is_varlen + ) + + if sequence_parallel: + dq = dq.sum(dim=0) + + if DEBUG: + print("attention_prefill_backward_triton_impl outputs") + print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + 
print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + print("copy_back:", copy_back) + if use_dropout: + print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None) + print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) + write_dropout_mask(dropout_mask, "dropout_mask_bwd") + + return dq, dk, dv, delta, None, None diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py new file mode 100644 index 000000000..41cdfbb73 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -0,0 +1,1099 @@ +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shape_from_layout, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH +tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +# fwd_prefill.py line 607 +@triton.jit +def _bwd_preprocess( + O, DO, # noqa: E741 + Delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + cu_seqlens_q, max_seqlen_q, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, HEAD_DIM) + # Offset O/DO by batch, head and q_start + O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 + DO += bid * stride_ob + hid * stride_oh + q_start * stride_om + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + # compute pointers + offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + out_ptrs = O + offs_do + do_ptrs = DO + offs_do + # load + o = tl.load(out_ptrs, mask=mask_md, other=0.0).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_md, other=0.0).to(tl.float32) + # compute and write-back to delta + delta = tl.sum(o * do, axis=1) + delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam + tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _bwd_dkdv_inner( + dk, dv, # output + Q, k, v, DO, M, D, sm_scale, # input tensor + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM: tl.constexpr, # + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. 
+ start_n, start_m, num_steps, # iteration numbers + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_EXP2: tl.constexpr, # activate dropout + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k = tl.arange(0, HEAD_DIM) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM + mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_nm + ) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + qkT = tl.dot(k, qT) + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT * sm_scale) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT * sm_scale * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + # Autoregressive masking. + if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. 
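        # dv accumulates pT @ do; with dropout enabled, pT is first masked and then
        # rescaled by 1 / (1 - dropout_p) so the expected value of dv is unchanged.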
+ if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + pT_dropout = pT_dropout.to(tl.float16) + dv += tl.dot(pT_dropout, do) + else: + pT = pT.to(tl.float16) + dv += tl.dot(pT, do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + + +# grid = (max_seqlen_k // BLOCK_N, batch, nheads_q) +@triton.jit +def _bwd_kernel_dkdv( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
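    # Illustrative example: with seqlen_q = 256, seqlen_k = 512 and BLOCK_N = 128,
    # delta_qk = -256 and num_blocks_skip = 2, so the first two N-tiles are fully
    # visible to every query (no causal boundary crosses them) and skip the masked
    # inner loop entirely; later tiles start their M iteration at a BLOCK_M-aligned
    # start_m instead of row 0.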
+ delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") # noqa: E701 + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") # noqa: E701 + # align the delta_qk + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, HEAD_DIM) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + GROUP_SIZE: tl.constexpr = HQ // HK + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. 
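    # Each program loads its K/V tile once (above) and loops over the GROUP_SIZE query
    # heads that share that KV head, accumulating dk and dv across the whole group.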
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + if DEBUG_TRITON: print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + \ + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + MASK=False, # causal masking + 
ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, K, V, do, m, Delta, sm_scale, # input + # shared by Q/K/V. + stride_qm, stride_qk, stride_kn, + stride_dropoutm, stride_dropoutn, # stride for dropout + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + # Filled in by the wrapper. + start_m, start_n, end_n, num_steps, # + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_qk + vT_ptrs = V + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_qk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
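+    # NOTE: with the launch parameters used by the wrapper in this file,
+    # (BLOCK_M2, BLOCK_N2) is (128, 32) for the unmasked pass and (128, 16)
+    # for the masked pass (BLOCK_N // BLK_SLICE_FACTOR), so the static assert
+    # below holds for both calls.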
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 + if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_mn) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + qk = tl.dot(q, kT) + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk * sm_scale) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk * sm_scale - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. 
+ curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_kn + return dq + + +# grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) +@triton.jit +def _bwd_kernel_dq( + Q, K, V, sm_scale, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + # Figure out causal starting block since we have seqlen_q <=> seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we + # can simply skip and we need to adjust starting position. + start_m = pid * BLOCK_M + # seqlen_q > seqlen_k, no need to process these tile for dq + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M}") # noqa: E701 + if start_m + BLOCK_M < delta_qk: + if DEBUG_TRITON: print(f"start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M} < delta_qk of {delta_qk}") # noqa: E701 + return + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. 
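+    # NOTE (illustrative example, sizes assumed): the early return above
+    # skips leading M-tiles when seqlen_q > seqlen_k, since with the
+    # bottom-right aligned causal mask those query rows attend to no keys.
+    # For seqlen_q = 512, seqlen_k = 128 (delta_qk = 384) and BLOCK_M = 128,
+    # the programs for start_m = 0 and 128 return immediately; start_m = 256
+    # is still processed because the check is a strict `<`.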
+ GROUP_SIZE: tl.constexpr = HQ // HK + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) + + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + if DEBUG_TRITON: print(f"pid: {pid}; end_n: {end_n}, start_m: {start_m}") # noqa: E701 + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _bwd_dq_inner, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + if DEBUG_TRITON: print(f"Masked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M, MASK_BLOCK_N, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N + num_steps = tl.cdiv(end_n, BLOCK_N) + start_n = max(end_n - num_steps * BLOCK_N, 0) + if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M, BLOCK_N, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. 
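+        # NOTE: the scores are S = sm_scale * Q @ K^T, so dQ = sm_scale *
+        # dS @ K. The inner loop accumulates dS @ K unscaled (k is not
+        # pre-scaled here), and the single `dq *= sm_scale` below applies the
+        # softmax scaling once.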
+ adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +@triton.jit +def _bwd_kernel_dkdv_noncausal( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, HEAD_DIM) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + GROUP_SIZE: tl.constexpr = HQ // HK + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. 
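+    # NOTE: without causal masking there is no diagonal to special-case; the
+    # K/V tile loaded above is reused for every Q head in the group and the
+    # inner loop below sweeps all of seqlen_q
+    # (num_steps = cdiv(seqlen_q, BLOCK_M)) with MASK=False.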
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dq_noncausal( + Q, K, V, sm_scale, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + 
mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M, BLOCK_N, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. 
+ adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def attention_prefill_backward_triton_split_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale: float, + alibi_slopes, + causal, + layout: str, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p, + philox_seed, + philox_offset, + use_exp2: bool, + DEBUG_TRITON: bool = False, + DEBUG_TRITON_DETAIL: bool = False, +): + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + dq.zero_() + dk.zero_() + dv.zero_() + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = \ + get_shape_from_layout( + q, k, layout, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k + ) + q_strides, k_strides, _, o_strides = \ + get_strides_from_layout(q, k, v, o, layout) + stride_qb, stride_qh, stride_qm, stride_qk = q_strides + stride_kb, stride_kh, stride_kn, stride_kk = k_strides + stride_ob, stride_oh, stride_om, stride_ok = o_strides + dq_strides, dk_strides, _, do_strides = \ + get_strides_from_layout(dq, dk, dv, do, layout) + stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides + stride_dob, stride_doh, stride_dom, stride_dok = do_strides + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + + # get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + HEAD_DIM = padded_d_model + ACTUAL_HEAD_DIM = head_size + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + # init delta + delta = torch.empty_like(softmax_lse) + if IS_VARLEN: + stride_deltab = 0 + stride_deltam, stride_deltah = delta.stride() + else: + stride_deltab, stride_deltah, stride_deltam = delta.stride() + pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, nheads_q) + _bwd_preprocess[pre_grid]( + o, do, + delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + cu_seqlens_q, max_seqlen_q, + BLOCK_M=PRE_BLOCK, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + IS_VARLEN=IS_VARLEN + ) + + # dropout mask tensor for debugging. 
We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + (0, 0 , 0 , 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32 + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + seed = philox_seed + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, + cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + dropout_mask.stride() + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + if causal: + if DEBUG_TRITON: print(f"_bwd_kernel_dkdv: grid = {grid_dkdv}, block_size = ({BLOCK_M1, BLOCK_N1})", ) # noqa: E701 + _bwd_kernel_dkdv[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + if DEBUG_TRITON: print(f"\n_bwd_kernel_dq: grid = {grid_dq}, block_size = ({BLOCK_M2, BLOCK_N2})", ) # noqa: E701 + _bwd_kernel_dq[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON and False, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL and False, + ) + else: + _bwd_kernel_dkdv_noncausal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = 
WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + _bwd_kernel_dq_noncausal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON and False, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL and False, + ) + + return dq, dk, dv, delta, None, None diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split_experiment.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split_experiment.py new file mode 100644 index 000000000..f3e6410e0 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split_experiment.py @@ -0,0 +1,1067 @@ +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shape_from_layout, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH +tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +# fwd_prefill.py line 607 +@triton.jit +def _bwd_preprocess( + O, DO, # noqa: E741 + Delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + cu_seqlens_q, max_seqlen_q, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, HEAD_DIM) + # Offset O/DO by batch, head and q_start + O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 + DO += bid * stride_ob + hid * stride_oh + q_start * stride_om + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + # compute pointers + offs_o = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + out_ptrs = O + offs_o + do_ptrs = DO + offs_o + # load + # o = tl.load(out_ptrs, mask=mask_md, other=0.0).to(tl.float32) + # do = tl.load(do_ptrs, mask=mask_md, other=0.0).to(tl.float32) + o = tl.load(out_ptrs).to(tl.float32) + do = tl.load(do_ptrs).to(tl.float32) + # compute and write-back to delta + delta = tl.sum(o * do, axis=1) + delta_offset = 
Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam + # tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) + tl.store(delta_offset + offs_m * stride_deltam, delta) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _bwd_dkdv_inner( + dk, dv, # output + Q, k, v, DO, M, D, sm_scale, # input tensor + stride_qm, stride_qk, # shared by Q/DO. + stride_dropoutm, stride_dropoutn, # + stride_deltam, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM: tl.constexpr, # + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. + start_n, start_m, num_steps, # iteration numbers + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_EXP2: tl.constexpr, # activate dropout + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k = tl.arange(0, HEAD_DIM) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + # mask_m = offs_m < seqlen_q + # mask_qT = mask_m[None, :] + # mask_do = mask_m[:, None] + # mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + # if PADDED_HEAD: + # mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM + # mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + # qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + qT = tl.load(qT_ptrs) + # generate dropout mask + # if ENABLE_DROPOUT: + # # NOTE: dropout is transposed because it is used to mask pT + # philox_offs = curr_philox_offset + \ + # offs_m[None, :] * stride_dropoutm + \ + # offs_n[:, None] * stride_dropoutn + # if tl_DROPOUT_USE_PYTORCH: + # dropout_offs = offs_m[None, :] * stride_dropoutm + \ + # offs_n[:, None] * stride_dropoutn + # dropout_mask = tl.load( + # curr_dropout_offset + dropout_offs, + # mask=mask_nm + # ) + # else: + # rand_vals = tl.rand(philox_seed, philox_offs) + # dropout_mask = rand_vals > dropout_p + # dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. 
+ # m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + m = tl.load(M + offs_m * stride_deltam) + qkT = tl.dot(k, qT) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT * sm_scale * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + # Autoregressive masking. + if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + # mask = causal_mask & mask_nm + mask = causal_mask + pT = tl.where(mask, pT, 0.0) + # do = tl.load(do_ptrs, mask=mask_do, other=0.0) + do = tl.load(do_ptrs) + # Compute dV. + # if ENABLE_DROPOUT: + # pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + # pT_dropout = pT_dropout.to(tl.float16) + # dv += tl.dot(pT_dropout, do) + # else: + pT = pT.to(tl.float16) + dv += tl.dot(pT, do) + + # D (= delta) is pre-divided by ds_scale. + # Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + Di = tl.load(D + offs_m * stride_deltam) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + # if ENABLE_DROPOUT: + # dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_qm + return dk, dv + + +# grid = (max_seqlen_k // BLOCK_N, batch, nheads_q) +@triton.jit +def _bwd_kernel_dkdv( + Q, K, V, sm_scale, Out, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
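+    # NOTE (illustrative example, sizes assumed): the causal mask is aligned
+    # to the bottom-right of the (seqlen_q, seqlen_k) score matrix, i.e.
+    # query row m may attend to key n iff n <= m - delta_qk. For seqlen_q = 2
+    # and seqlen_k = 4, delta_qk = -2, so row 0 sees keys 0..2 and row 1 sees
+    # keys 0..3.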
+    delta_qk = seqlen_q - seqlen_k
+    # q > k: directly skip all the way until the start of the causal block
+    start_delta_q_gt_k = delta_qk
+    # q < k: some blocks will have no masked block, others need to
+    # recalculate the starting position
+    # delta_qk is negative so flip it; only multiples of BLOCK_N can skip the
+    # masked op
+    num_blocks_skip = -delta_qk // BLOCK_N
+    delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk
+    start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M
+    if delta_qk >= 0:
+        start_delta = delta_qk
+    else:
+        start_delta = start_delta_q_lt_k
+    # align the delta_qk
+    start_n = pid * BLOCK_N
+
+    offs_k = tl.arange(0, HEAD_DIM)
+    offs_n = start_n + tl.arange(0, BLOCK_N)
+    # Mask for loading K and V
+    # mask_kv = offs_n[:, None] < seqlen_k
+    # PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM)
+    # if PADDED_HEAD:
+    #     mask_k = offs_k < ACTUAL_HEAD_DIM
+    #     mask_kv &= mask_k[None, :]
+    offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
+
+    # K/V tensors are not changed within the group
+    adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn
+    # load K and V: they stay in SRAM throughout the inner loop.
+    # k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0)
+    # v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0)
+    k = tl.load(K + adj_kv + offs_kv)
+    v = tl.load(V + adj_kv + offs_kv)
+    if delta_qk >= 0:
+        start_m = start_n + start_delta
+        len_m = BLOCK_N
+    else:
+        start_m = max(start_n + delta_qk, 0)
+        start_m = start_m // BLOCK_M * BLOCK_M
+        # because we might shift the masked blocks up, we are deeper into
+        # the masked-out region, so we would potentially increase the total
+        # number of masked steps needed to get out of it
+        residue_m = max(start_n + delta_qk - start_m, 0)
+        len_m = BLOCK_N + residue_m
+        if DEBUG_TRITON: print(f"residue_m = {residue_m}")
+
+    MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR
+    # bound the masked operation to q len so it does not have to waste cycles
+    len_m = min(len_m, seqlen_q)
+    num_steps = tl.cdiv(len_m, MASK_BLOCK_M)
+    # when q < k, we may skip the initial masked op
+    if pid < num_blocks_skip:
+        num_steps = 0
+    start_m_unmasked = start_m + num_steps * MASK_BLOCK_M
+    num_steps_unmasked = tl.cdiv(seqlen_q - start_m_unmasked, BLOCK_M)
+    # If MQA / GQA, set the K and V head offsets appropriately.
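+    # NOTE: in this experimental variant the boundary masks are commented out
+    # and the per-group loop is collapsed to `hqid = hkid` below, so it
+    # effectively assumes HQ == HK (no MQA/GQA), sequence lengths that are
+    # multiples of the block sizes, and head_size == HEAD_DIM; the
+    # non-experimental kernels earlier in this diff keep the masked loads and
+    # the group loop.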
+ GROUP_SIZE = HQ // HK + hqid = hkid + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + DO_ptr = DO + adj_q + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m_unmasked, num_steps_unmasked, # iteration numbers + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + # tl.store(DV + adj_kv + offs_kv, dv, mask=mask_kv) + # dk *= sm_scale + # tl.store(DK + adj_kv + offs_kv, dk, mask=mask_kv) + tl.store(DV + adj_kv + offs_kv, dv) + dk *= sm_scale + tl.store(DK + adj_kv + offs_kv, dk) + + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, K, V, do, m, Delta, sm_scale, # input + # shared by Q/K/V/DO. + stride_qm, stride_qk, stride_kn, + stride_dropoutm, stride_dropoutn, # stride for dropout + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + # Filled in by the wrapper. 
+ start_m, start_n, end_n, num_steps, # + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + # PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + + # mask to make sure not OOB of seqlen_q + # mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_qk + vT_ptrs = V + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_qk + # D (= delta) is pre-divided by ds_scale. + # Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + Di = tl.load(Delta + offs_m * stride_deltam) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 + if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + # mask_kT = mask_n[None, :] + # mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + # if PADDED_HEAD: + # mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + + # kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + # vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + + # if ENABLE_DROPOUT: + # # NOTE: dropout is transposed because it is used to mask pT + # philox_offs = curr_philox_offset + \ + # offs_m[:, None] * stride_dropoutm + \ + # offs_n[None, :] * stride_dropoutn + # if tl_DROPOUT_USE_PYTORCH: + # dropout_offs = offs_m[:, None] * stride_dropoutm + \ + # offs_n[None, :] * stride_dropoutn + # dropout_mask = tl.load( + # curr_dropout_offset + dropout_offs, + # mask=mask_mn) + # else: + # rand_vals = tl.rand(philox_seed, philox_offs) + # dropout_mask = rand_vals > dropout_p + # dropout_scale = 1 / (1 - dropout_p) + + qk = tl.dot(q, kT) + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk * sm_scale) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk * sm_scale - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + # mask = causal_mask & mask_mn + mask = causal_mask + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + # if ENABLE_DROPOUT: + # dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. 
+ curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_kn + return dq + + +# grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) +@triton.jit +def _bwd_kernel_dq( + Q, K, V, sm_scale, Out, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + # Figure out causal starting block since we have seqlen_q <=> seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we + # can simply skip and we need to adjust starting position. + start_m = pid * BLOCK_M + # seqlen_q > seqlen_k, no need to process these tile for dq + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M}") # noqa: E701 + if start_m + BLOCK_M < delta_qk: + if DEBUG_TRITON: print(f"start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M} < delta_qk of {delta_qk}") # noqa: E701 + return + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. 
+ GROUP_SIZE = HQ // HK + hqid = hkid + + # for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + # q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + # do = tl.load(DO + adj_q + offs_q, mask=mask_q, other=0.0) + # m = tl.load(M + adj_delta + offs_m * stride_deltam, + # mask=offs_m < seqlen_q) + q = tl.load(Q + adj_q + offs_q) + do = tl.load(DO + adj_q + offs_q) + m = tl.load(M + adj_delta + offs_m * stride_deltam) + m = m[:, None] + + MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) + + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + if DEBUG_TRITON: print(f"pid: {pid}; end_n: {end_n}, start_m: {start_m}") # noqa: E701 + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _bwd_dq_inner, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + if DEBUG_TRITON: print(f"Masked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M, MASK_BLOCK_N, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N + num_steps = tl.cdiv(end_n, BLOCK_N) + start_n = max(end_n - num_steps * BLOCK_N, 0) + if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M, BLOCK_N, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. 
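+    # NOTE: DQ is addressed below with Q's base offset and strides
+    # (stride_qm / stride_qk), so this variant assumes dq shares q's layout,
+    # as it does when the wrapper allocates it with torch.zeros_like(q).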
+ offs_dq = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + dq *= sm_scale + # tl.store(DQ + adj_q + offs_dq, dq, mask=mask_q) + tl.store(DQ + adj_q + offs_dq, dq) + + +@triton.jit +def _bwd_kernel_dkdv_noncausal( + Q, K, V, sm_scale, Out, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, HEAD_DIM) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + GROUP_SIZE = HQ // HK + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. 
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + DO_ptr = DO + adj_q + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + tl.store(DV + adj_kv + offs_kv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_kv + offs_kv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dq_noncausal( + Q, K, V, sm_scale, Out, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. 
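+    # NOTE: unlike the causal kernels above in this experimental file, the
+    # non-causal ones keep the masked K/V and Q/DO loads and the full
+    # GROUP_SIZE loop over query heads; the shared inner loops still use
+    # unmasked loads.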
+ GROUP_SIZE = HQ // HK + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_q + offs_q, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M, BLOCK_N, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + offs_dq = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + dq *= sm_scale + tl.store(DQ + adj_q + offs_dq, dq, mask=mask_q) + + +def attention_prefill_backward_triton_split_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale: float, + alibi_slopes, + causal, + layout: str, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p, + philox_seed, + philox_offset, + use_exp2: bool, + DEBUG_TRITON: bool = False, + DEBUG_TRITON_DETAIL: bool = False, +): + # make contigious + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + softmax_lse = softmax_lse.contiguous() # (batch, head_q, seqlen_q) + do = do.contiguous() + + if dq is None: + dq = torch.zeros_like(q) + if dk is None: + dk = torch.zeros_like(k) + if dv is None: + dv = torch.zeros_like(v) + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = \ + get_shape_from_layout( + q, k, layout, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k + ) + q_strides, k_strides, v_strides, o_strides = \ + get_strides_from_layout(q, k, v, o, layout) + stride_qb, stride_qh, stride_qm, stride_qk = q_strides + stride_kb, stride_kh, stride_kn, stride_kk = k_strides + stride_vb, stride_vh, stride_vn, stride_vk = v_strides + stride_ob, stride_oh, stride_om, stride_ok = o_strides + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + + # get closest power of 2 over or equal to 32. 
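+    # NOTE: despite the wording above, the computation below pads the head
+    # dim to the next power of two with a floor of 16, e.g. head_size 72 ->
+    # HEAD_DIM 128 and head_size 8 -> HEAD_DIM 16.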
+ padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + HEAD_DIM = padded_d_model + ACTUAL_HEAD_DIM = head_size + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + # init delta + delta = torch.empty_like(softmax_lse) + if IS_VARLEN: + stride_deltab = 0 + stride_deltam, stride_deltah = delta.stride() + else: + stride_deltab, stride_deltah, stride_deltam = delta.stride() + pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, nheads_q) + _bwd_preprocess[pre_grid]( + o, do, + delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + cu_seqlens_q, max_seqlen_q, + BLOCK_M=PRE_BLOCK, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + IS_VARLEN=IS_VARLEN + ) + + # dropout mask tensor for debugging. We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + (0, 0 , 0 , 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32 + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + seed = philox_seed + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, + cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + dropout_mask.stride() + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + if causal: + if DEBUG_TRITON: print(f"_bwd_kernel_dkdv: grid = {grid_dkdv}, block_size = ({BLOCK_M1, BLOCK_N1})", ) # noqa: E701 + _bwd_kernel_dkdv[grid_dkdv]( + q, k, v, sm_scale, o, do, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + if DEBUG_TRITON: print(f"\n_bwd_kernel_dq: grid = {grid_dq}, block_size = ({BLOCK_M2, BLOCK_N2})", ) # noqa: E701 + _bwd_kernel_dq[grid_dq]( + q, k, v, sm_scale, o, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON and False, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL and False, + ) + else: + 
_bwd_kernel_dkdv_noncausal[grid_dkdv]( + q, k, v, sm_scale, o, do, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + _bwd_kernel_dq_noncausal[grid_dq]( + q, k, v, sm_scale, o, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_deltab, stride_deltah, stride_deltam, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON and False, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL and False, + ) + + return dq, dk, dv, delta, None, None diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split_oneKernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split_oneKernel.py new file mode 100644 index 000000000..20d8bf3d1 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split_oneKernel.py @@ -0,0 +1,665 @@ +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shape_from_layout, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen +from .bwd_prefill_split import _bwd_preprocess, _bwd_dkdv_inner, _bwd_dq_inner + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH +tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP + + +# grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) +@triton.jit +def bwd_kernel( + Q, K, V, sm_scale, DO, DQ, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 
+    k_start = 0
+    seqlen_q = max_seqlen_q
+    seqlen_k = max_seqlen_k
+    if IS_VARLEN:
+        # Compute actual sequence lengths
+        q_start = tl.load(cu_seqlens_q + bid)
+        q_end = tl.load(cu_seqlens_q + bid + 1)
+        k_start = tl.load(cu_seqlens_k + bid)
+        k_end = tl.load(cu_seqlens_k + bid + 1)
+        seqlen_q = q_end - q_start
+        seqlen_k = k_end - k_start
+
+    delta_qk = seqlen_q - seqlen_k
+    if DEBUG_TRITON: print(f"delta_qk = {delta_qk}")  # noqa: E701
+    PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM)
+    offs_k = tl.arange(0, HEAD_DIM)
+    GROUP_SIZE: tl.constexpr = HQ // HK
+
+    # align the delta_qk
+    start_n = pid * BLOCK_N1
+    if start_n < seqlen_k:
+        # This section does dk and dv
+        dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
+        dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
+
+        # q > k: directly skip all the way until the start of the causal block
+        start_delta_q_gt_k = delta_qk
+        # q < k: some blocks have no masked block at all, others need to
+        # re-calculate the starting position
+        # delta_qk is negative, so flip it; only a multiple of BLOCK_N can skip
+        # the masked op
+        num_blocks_skip = -delta_qk // BLOCK_N1
+        delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk
+        start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1
+        if delta_qk >= 0:
+            start_delta = delta_qk
+            if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}")  # noqa: E701
+        else:
+            start_delta = start_delta_q_lt_k
+            if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}")  # noqa: E701
+
+        offs_n = start_n + tl.arange(0, BLOCK_N1)
+        # Mask for loading K and V
+        mask_kv = offs_n[:, None] < seqlen_k
+        if PADDED_HEAD:
+            mask_k = offs_k < ACTUAL_HEAD_DIM
+            mask_kv &= mask_k[None, :]
+        offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
+
+        # K/V tensors are not changed for the group
+        adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn
+        # load K and V: they stay in SRAM throughout the inner loop.
+        k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0)
+        v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0)
+        # If MQA / GQA, set the K and V head offsets appropriately.
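+        # The K/V tile loaded above is reused for every query head in the
+        # group. For each head, the loop below first runs a sliced masked pass
+        # over the blocks that straddle the causal diagonal (MASK_BLOCK_M1 rows
+        # per step), then an unmasked pass over the remaining query blocks.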
+ # hqid = hkid + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N1 + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M1 * BLOCK_M1 + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N1 + residue_m + if DEBUG_TRITON: print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + \ + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M1, BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M1 + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) + end_m = start_m + num_steps * BLOCK_M1 + + if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + 
MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # end of GQA/MQA of dkdv + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + # This part does dq + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + # seqlen_q > seqlen_k, no need to process these tile for dq + if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}") # noqa: E701 + if start_m + BLOCK_M2 < delta_qk: + if DEBUG_TRITON: print(f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}") # noqa: E701 + return + + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M2 - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M2, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, MASK_BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N2 + num_steps = tl.cdiv(end_n, BLOCK_N2) + start_n = max(end_n - num_steps * 
BLOCK_N2, 0) + if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + # end of GQA/MQA of dq + + +@triton.jit +def bwd_kernel_noncausal( + Q, K, V, sm_scale, DO, DQ, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M1: tl.constexpr, # 32 + BLOCK_N1: tl.constexpr, # 128 + BLOCK_M2: tl.constexpr, # 128 + BLOCK_N2: tl.constexpr, # 32 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + offs_k = tl.arange(0, HEAD_DIM) + GROUP_SIZE: tl.constexpr = HQ // HK + + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. 
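+        # Without causal masking every K/V tile sees the whole query range, so
+        # the dk/dv accumulation below always starts at start_m = 0 and runs
+        # ceil(seqlen_q / BLOCK_M1) unmasked steps per query head.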
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M1) + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + # THIS PART DOES DQ + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. 
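+        # Same GQA walk as the dk/dv half above, but here each query head
+        # accumulates its own BLOCK_M2 x HEAD_DIM dq tile over all of seqlen_k.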
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N2) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def attention_prefill_backward_triton_split_oneKernel_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale: float, + alibi_slopes, + causal, + layout: str, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p, + philox_seed, + philox_offset, + use_exp2: bool, + DEBUG_TRITON: bool = False, + DEBUG_TRITON_DETAIL: bool = False, +): + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + dq.zero_() + dk.zero_() + dv.zero_() + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = \ + get_shape_from_layout( + q, k, layout, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k + ) + q_strides, k_strides, _, o_strides = \ + get_strides_from_layout(q, k, v, o, layout) + stride_qb, stride_qh, stride_qm, stride_qk = q_strides + stride_kb, stride_kh, stride_kn, stride_kk = k_strides + stride_ob, stride_oh, stride_om, stride_ok = o_strides + dq_strides, dk_strides, _, do_strides = \ + get_strides_from_layout(dq, dk, dv, do, layout) + stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides + stride_dob, stride_doh, stride_dom, stride_dok = do_strides + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + + # get closest power of 2 over or equal to 32. 
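+    # pad the head dim to a power of two (minimum 16), mirroring the split
+    # launcher above; the kernels mask the padded lanes via ACTUAL_HEAD_DIM.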
+ padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + HEAD_DIM = padded_d_model + ACTUAL_HEAD_DIM = head_size + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + # init delta + delta = torch.empty_like(softmax_lse) + if IS_VARLEN: + stride_deltab = 0 + stride_deltam, stride_deltah = delta.stride() + else: + stride_deltab, stride_deltah, stride_deltam = delta.stride() + pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, nheads_q) + _bwd_preprocess[pre_grid]( + o, do, + delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + cu_seqlens_q, max_seqlen_q, + BLOCK_M=PRE_BLOCK, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + IS_VARLEN=IS_VARLEN + ) + + # dropout mask tensor for debugging. We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + (0, 0 , 0 , 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32 + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + seed = philox_seed + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, + cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + dropout_mask.stride() + + assert BLOCK_N1 == BLOCK_M2 + seqlen = max(max_seqlen_q, max_seqlen_k) + grid = ((seqlen + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + if causal: + if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}, block_size = ({BLOCK_M1, BLOCK_N1})", ) # noqa: E701 + bwd_kernel[grid]( + q, k, v, sm_scale, do, dq, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + bwd_kernel_noncausal[grid]( + q, k, v, sm_scale, do, dq, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, 
+ USE_EXP2=use_exp2, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + return dq, dk, dv, delta, None, None diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py new file mode 100644 index 000000000..23c272334 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -0,0 +1,481 @@ +import torch +import math +from .utils import DEBUG + +DEBUG_CORE = False + +def attention_backward_core_ref_impl( + do, q, k, v, o, softmax_lse, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2 +): + if DEBUG_CORE: + print() + print("attention_backward_core_ref_impl") + print("do:", do, do.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("o:", o, o.shape) # is a bad number + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("sm_scale:", sm_scale) + print("causal:", causal) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) + print("use_exp2:", use_exp2) + + # cast to float32 + do = do.to(torch.float32) + q = q.to(torch.float32) + k = k.to(torch.float32) + v = v.to(torch.float32) + o = o.to(torch.float32) + softmax_lse = softmax_lse.to(torch.float32) + + + # recompute attention_scores. Make sure it matches the forward impl. i.e. It use float32 + attention_scores = torch.matmul(q, k.transpose(-2, -1)) + if DEBUG_CORE: + print("attention_scores:", attention_scores, attention_scores.shape) + + # scale scores + attention_scaled_scores = sm_scale * attention_scores + if DEBUG_CORE: + print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) + + # Apply causal mask if necessary + if causal: + L_q, L_k = q.shape[1], k.shape[1] + row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) + col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) + col_offset = L_q-L_k + causal_mask = row_idx >= (col_offset + col_idx) + if DEBUG_CORE: + print("causal_mask:", causal_mask) + # set -inf to places the causal mask is false + attention_scaled_scores = attention_scaled_scores.masked_fill( + torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') + ) + if DEBUG_CORE: + print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) + + # compute probabilities using softmax_lse + if use_exp2: + RCP_LN = 1 / math.log(2) + attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN + softmax_lse_base2 = softmax_lse * RCP_LN + softmax_lse_3d = softmax_lse_base2.unsqueeze(-1) + p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d) + else: + softmax_lse_3d = softmax_lse.unsqueeze(-1) + p = torch.exp(attention_scaled_scores - softmax_lse_3d) + if DEBUG_CORE: + print("softmax_lse_3d:", softmax_lse_3d, softmax_lse_3d.shape) + print("p:", p, p.shape) + + if dropout_p > 0.0: + rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) + dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) + if DEBUG: + print("dropout_scale:", dropout_scale) + print("dropout_mask:", dropout_mask) + + p_drop = torch.where(dropout_mask, p, torch.zeros_like(p)) + p_drop_scaled = p_drop * dropout_scale + if DEBUG_CORE: + print("dropout_scale:", dropout_scale) + print("p_drop:", p_drop, p_drop.shape) + print("p_drop_scaled:", p_drop_scaled, p_drop_scaled.shape) + + # compute dv + dv = 
torch.matmul(p_drop_scaled.transpose(-2, -1), do) + if DEBUG_CORE: + print("dv:", dv, dv.shape) + + # compute dp + dp_dropout = torch.matmul(do, v.transpose(-2, -1)) + dp = torch.where(dropout_mask, dp_dropout , torch.zeros_like(dp_dropout)) * dropout_scale + if DEBUG_CORE: + print("dp_dropout:", dp_dropout, dp_dropout.shape) + print("dp:", dp, dp.shape) + + # calculate ds + if True: + delta = torch.sum(o * do, axis=-1).unsqueeze(-1) + else: + delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) + dscores_scaled = p * (dp - delta) + ds = dscores_scaled * sm_scale + else: + # compute dv + dv = torch.matmul(p.transpose(-2, -1), do) + if DEBUG_CORE: + print("dv:", dv, dv.shape) + + # compute dp + dp = torch.matmul(do, v.transpose(-2, -1)) + if DEBUG_CORE: + print("dp:", dp, dp.shape) + + # calculate ds + delta = torch.sum(o * do, axis=-1).unsqueeze(-1) + dscores_scaled = p * (dp - delta) + ds = dscores_scaled * sm_scale + if DEBUG_CORE: + print("delta:", delta, delta.shape) + print("dscores_scaled:", dscores_scaled, dscores_scaled.shape) + print("ds:", ds, ds.shape) + + # compute gradient wrt k & q + dk = torch.matmul(ds.transpose(-2, -1), q) + dq = torch.matmul(ds, k) + if DEBUG_CORE: + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + + # cast back to original dtype + dq = dq.to(torch.float16) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + # remove d dim with size 1 + delta = delta.squeeze(-1) + + if DEBUG_CORE: + print("attention_backward_core_ref_impl output") + print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + + return dq, dk, dv, delta + +def attention_varlen_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + use_exp2, +): + # Ensure the layout is 'thd' + if layout != 'thd': + raise ValueError(f"Unsupported layout {layout}. 
Expected 'thd'.") + + batch_size = cu_seqlens_q.shape[0] - 1 + nheads_q, head_dim = q.shape[1], q.shape[2] + nheads_k = k.shape[1] + + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + # Pre-allocate outputs + total_L_q = q.shape[0] + total_L_k = k.shape[0] + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + # delta has the same shape as softmax_lse: [total_L_q, nheads_q] + delta = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=o.device) + + for i in range(batch_size): + # Get the start and end indices for the current sequence + start_q = cu_seqlens_q[i].item() + end_q = cu_seqlens_q[i + 1].item() + start_k = cu_seqlens_k[i].item() + end_k = cu_seqlens_k[i + 1].item() + + # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i + q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + do_i = do[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + o_i = o[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, nheads_q] + + if group_size != 1: + # MQA or GQA case + # Reshape tensors to include group dimension + q_i = q_i.view(q_i.shape[0], nheads_k, group_size, head_dim) + do_i = do_i.view(do_i.shape[0], nheads_k, group_size, head_dim) + o_i = o_i.view(o_i.shape[0], nheads_k, group_size, head_dim) + softmax_lse_i = softmax_lse_i.view(softmax_lse_i.shape[0], nheads_k, group_size) + # Expand k_i and v_i to match group_size + k_i = k_i.unsqueeze(2).expand(-1, -1, group_size, -1) + v_i = v_i.unsqueeze(2).expand(-1, -1, group_size, -1) + # Flatten the nheads_k and group_size dimensions + q_i = q_i.reshape(q_i.shape[0], nheads_k * group_size, head_dim) + do_i = do_i.reshape(do_i.shape[0], nheads_k * group_size, head_dim) + o_i = o_i.reshape(o_i.shape[0], nheads_k * group_size, head_dim) + softmax_lse_i = softmax_lse_i.reshape(softmax_lse_i.shape[0], nheads_k * group_size) + k_i = k_i.reshape(k_i.shape[0], nheads_k * group_size, head_dim) + v_i = v_i.reshape(v_i.shape[0], nheads_k * group_size, head_dim) + # Permute to [nheads_total, L, head_dim] + q_i = q_i.permute(1, 0, 2) + k_i = k_i.permute(1, 0, 2) + v_i = v_i.permute(1, 0, 2) + do_i = do_i.permute(1, 0, 2) + o_i = o_i.permute(1, 0, 2) + softmax_lse_i = softmax_lse_i.transpose(0, 1) + + # Call the core backward function for this sequence + dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl( + do_i, + q_i, + k_i, + v_i, + o_i, + softmax_lse_i, + sm_scale, + causal, + dropout_p, + philox_seed, + philox_offset, + use_exp2 + ) + + # Convert back to 'thd' layout + dq_i = dq_i.permute(1, 0, 2) # [L_q_i, nheads_total, head_dim] + dk_i = dk_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] + dv_i = dv_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] + delta_i = delta_i.transpose(1, 0) # [L_q_i, nheads_total] + + if group_size != 1: + # Reshape dq_i and delta_i back to original shape + dq_i = dq_i.view(dq_i.shape[0], nheads_k, group_size, head_dim) + delta_i = delta_i.view(delta_i.shape[0], nheads_k, group_size) + # Sum dk_i and dv_i over group dimension + dk_i = dk_i.view(dk_i.shape[0], nheads_k, group_size, head_dim) + dv_i = dv_i.view(dv_i.shape[0], nheads_k, group_size, head_dim) + dk_i = dk_i.sum(dim=2) + dv_i = dv_i.sum(dim=2) + # Reshape dq_i back to [L_q_i, nheads_q, head_dim] + dq_i = dq_i.reshape(dq_i.shape[0], nheads_q, 
head_dim) + delta_i = delta_i.reshape(delta_i.shape[0], nheads_q) + else: + # No need to reshape + pass + + # Place outputs in pre-allocated tensors + dq[start_q:end_q, :, :] = dq_i + dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys + dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values + delta[start_q:end_q, :] = delta_i + + return dq, dk, dv, delta + +def attention_vanilla_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + dropout_p, + philox_seed, + philox_offset, + use_exp2, +): + if layout == "bshd": + if DEBUG: + print() + print("Changing layout to bhsd!") + do = do.transpose(1, 2).contiguous() + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + o = o.transpose(1, 2).contiguous() + elif layout == "bhsd": + pass + else: + raise ValueError(f"Unknown layout {layout}") + + # Prepare tensors + batch_size, nheads_q, seq_len_q, head_dim = q.shape + batch_size, nheads_k, seq_len_k, head_dim = k.shape + + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + if group_size != 1: + # MQA or GQA case + # Reshape do, q, o to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + do = do.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Reshape softmax_lse to [batch_size, nheads_k, group_size, seq_len_q] + softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) + # Expand k and v to match group_size + k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) # [batch_size, nheads_k, group_size, seq_len_k, head_dim] + v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + # Flatten the first three dimensions for computation + do = do.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + o = o.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size * nheads_k * group_size, seq_len_q) + else: + # Standard case + do = do.reshape(batch_size * nheads_q, seq_len_q, head_dim) + q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) + o = o.reshape(batch_size * nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size * nheads_q, seq_len_q) + + # Call the core backward function + dq, dk, dv, delta = attention_backward_core_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + dropout_p, + philox_seed, + philox_offset, + use_exp2 + ) + + if group_size != 1: + # Reshape dq back to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + dq = dq.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Reshape delta back to [batch_size, nheads_k, group_size, seq_len_q] + delta = delta.reshape(batch_size, nheads_k, group_size, seq_len_q) + # Sum dk and dv over group_size dimension, since k and v are shared across groups + dk = dk.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) + dk = dk.sum(dim=2) # Sum over group_size dimension + dv = 
dv.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) + dv = dv.sum(dim=2) + # Reshape dq to [batch_size, nheads_q, seq_len_q, head_dim] + dq = dq.reshape(batch_size, nheads_k * group_size, seq_len_q, head_dim) + delta = delta.reshape(batch_size, nheads_k * group_size, seq_len_q) + else: + # Standard case + dq = dq.reshape(batch_size, nheads_q, seq_len_q, head_dim) + dk = dk.reshape(batch_size, nheads_k, seq_len_k, head_dim) + dv = dv.reshape(batch_size, nheads_k, seq_len_k, head_dim) + delta = delta.reshape(batch_size, nheads_q, seq_len_q) + + # Go back to original layout + if layout == "bshd": + if DEBUG: + print() + print("Changing back to bshd!") + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + dv = dv.transpose(1, 2) + elif layout == "bhsd": + pass + else: + raise ValueError(f"Unknown layout {layout}") + + return dq, dk, dv, delta + +def attention_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + use_exp2 +): + + if DEBUG: + print() + print("attention_backward_pytorch_ref_impl") + print("do:", do, do.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("o:", o, o.shape) + print("softmax_lse:", softmax_lse) + print("sm_scale:", sm_scale) + print("causal:", causal) + print("layout:", layout) + print("cu_seqlens_q:", cu_seqlens_q) + print("cu_seqlens_k:", cu_seqlens_k) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) + print("use_exp2:", use_exp2) + + + if layout == "thd": + dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + use_exp2, + ) + else: + dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + dropout_p, + philox_seed, + philox_offset, + use_exp2, + ) + + + if DEBUG: + print() + print("attention_backward_pytorch_ref_impl outputs") + print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + + return dq, dk, dv, delta diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py new file mode 100644 index 000000000..b37308be4 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -0,0 +1,703 @@ +import torch +import triton +import triton.language as tl +from .utils import _strides, get_padded_headsize + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + K_new, + V_new, + Cache_seqlens, + Cache_batch_idx, + Alibi_slopes, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qd, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kd, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vd, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_d, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_kn_z, + stride_kn_n, + stride_kn_g, + stride_kn_h, + stride_kn_d, + stride_vn_z, + stride_vn_n, + stride_vn_g, + stride_vn_h, + stride_vn_d, + stride_az, + stride_ah, + Z, + 
N_CTX_Q, + N_CTX_K, + N_CTX_NEW, + BLOCK_N_PER_SPLIT, + H_q: tl.constexpr, + H_kv: tl.constexpr, + G_q: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_CACHE_SEQLENs: tl.constexpr, + USE_CACHE_BATCH_IDX: tl.constexpr, + NEW_KV: tl.constexpr, + IS_GQA: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_ALIBI: tl.constexpr, +): + # Padding + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + if PADDED_HEAD: + d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H_q * G_q) + off_h_q = (off_zhg // G_q) % H_q + off_g_q = off_zhg % G_q + splitk_idx = tl.program_id(2) + + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + off_z) + else: + cache_batch_idx = off_z + + # Load ALiBi slope if enabled + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + else: + alibi_slope = None + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_CACHE_SEQLENs: + cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z) + if NEW_KV: + kv_len = cache_seqlen_last_idx + N_CTX_NEW + else: + kv_len = cache_seqlen_last_idx + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + HEAD_RATIO: tl.constexpr = H_q // H_kv + if IS_GQA: + k_head_idx = off_h_q // HEAD_RATIO + v_head_idx = k_head_idx + else: + k_head_idx = off_h_q + v_head_idx = off_h_q + + # calculate base offset + k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg + v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg + + # Copy new Keys and Values into Cache + if NEW_KV: + knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g + + # Determine the starting position for new data in the cache + if USE_CACHE_SEQLENs: + start_idx = tl.load(Cache_seqlens + off_z) + else: + start_idx = N_CTX_K - N_CTX_NEW + + # Copy new Keys + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from K_new + k_new_block = tl.load( + knew_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + + (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to K + tl.store( + k_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + + (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, + k_new_block, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + ) + + # Copy new Values + vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from V_new + v_new_block = tl.load( + vnew_base + + (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to V + tl.store( + v_base + + (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, + v_new_block, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + ) + + Q_block_ptr = 
tl.make_block_ptr( + base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg, + shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qd), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + K_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(ACTUAL_BLOCK_DMODEL, hi), + strides=(stride_kd, stride_kn), + offsets=(0, lo), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, ACTUAL_BLOCK_DMODEL), + strides=(stride_vn, stride_vd), + offsets=(lo, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) + q = (q * qk_scale).to(q.dtype) + if PADDED_HEAD: + q = tl.where(d_mask[None, :], q, 0.0) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + 1, + BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL, + Q.dtype.element_ty, + 0, + ) + if PADDED_HEAD: + k = tl.where(d_mask[:, None], k, 0.0) + v = tl.where(d_mask[None, :], v, 0.0) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + if USE_ALIBI: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # Apply causal mask if IS_CAUSAL is True + if IS_CAUSAL: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_Q - kv_len + causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? 
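+        # Online-softmax update in base-2 form (q was pre-scaled by log2(e)):
+        # the new running row max m_i_new rescales the accumulator and the
+        # running sum l_i by alpha = 2^(m_i - m_i_new) before the current
+        # block's probabilities p are added.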
+ if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + if IS_CAUSAL: + alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf"))) + else: + alpha = tl.math.exp2(m_i - m_i_new) + # cause of nan because subtracting infs + if IS_CAUSAL: + qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf")) + else: + qk = qk - m_i_new[:, None] + + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(v.dtype), v) + + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, BLOCK_DMODEL), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0, ), + ) + # Write metadata for split-K reduction + Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + + tl.arange(0, BLOCK_M)) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + +@triton.jit +def load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. 
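+    # e.g. with PACKED_PER_VAL = 8, each int32 element of x_ holds eight 4-bit
+    # values; they are extracted below by shifting right by 0, 4, ..., 28 bits
+    # and masking the low nibble.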
+ + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + use_mask: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) + + o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + + if IS_CAUSAL: + l_m_offset = l_m - g_m + alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0) + else: + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + + if IS_CAUSAL: + # Avoid division by zero + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe + else: + acc_out = tl.sum(acc, axis=0) / g_sum + + # Store output + Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + + off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + tl.store(Out_ptr, acc_out) + + # Store lse + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + if IS_CAUSAL: + lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) 
/ 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + +def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new): + # kernel config + BLOCK_M = 16 + BLOCK_N = 64 + SPLIT_K = None + NUM_QUANT_GROUPS = 1 + + # kernels expects "bsghd" + original_layout = layout + if layout == "bshd": + q=q.unsqueeze(2) + k=k.unsqueeze(2) + v=v.unsqueeze(2) + if new_kv: + k_new = k_new.unsqueeze(2) + v_new = v_new.unsqueeze(2) + layout = "bsghd" + elif layout == "bhsd": + q=q.permute(0, 2, 1, 3).unsqueeze(2) + k=k.permute(0, 2, 1, 3).unsqueeze(2) + v=v.permute(0, 2, 1, 3).unsqueeze(2) + if new_kv: + k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2) + v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2) + layout = "bsghd" + elif layout == "bsghd": + pass + elif layout is None: + raise ValueError("Layout not given") + assert layout == "bsghd" + + # get dims + batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape + _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape + _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape + + assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}" + + # get padded size + dim_padded = get_padded_headsize(dim_k) + + # Handle MQA/GQA case + if heads_per_group_q > heads_per_group_k: + is_gqa = True + elif heads_per_group_q < heads_per_group_k: + raise ValueError("heads_per_group_q < heads_per_group_k") + else: + is_gqa = False + + assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}" + + if SPLIT_K is not 
None: + split_k = SPLIT_K + else: + # Use heuristics + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens? + + seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M + out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) + lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k) + + num_warps = 1 + split_size = (seqlen_k + split_k - 1) // split_k + use_cache_seqlens = cache_seqlens is not None + + # TODO: enable quantization + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + Out_splitK=out_splitk, + Metadata=metadata, + K_new = k_new, + V_new = v_new, + Cache_seqlens=cache_seqlens, + Cache_batch_idx=cache_batch_idx, + Alibi_slopes=alibi_slopes, + **_strides(q, "qz", "qm", "qg", "qh", "qd"), + **_strides(k, "kz", "kn", "kg", "kh", "kd"), + **_strides(v, "vz", "vn", "vg", "vh", "vd"), + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"), + **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"), + **_strides(alibi_slopes, "az", "ah"), + Z=batch_size, + H_q=heads_per_group_q, + H_kv=heads_per_group_k, + G_q=n_group_q, + N_CTX_Q=seqlen_q, + N_CTX_K=seqlen_k, + N_CTX_NEW=k_new.shape[1] if new_kv else None, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_k, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, + USE_CACHE_SEQLENs=use_cache_seqlens, + USE_CACHE_BATCH_IDX=cache_batch_idx is not None, + NEW_KV=new_kv, + IS_GQA=is_gqa, + IS_CAUSAL=causal, + USE_ALIBI=False if alibi_slopes is None else True, + num_warps=num_warps, + num_stages=1, + ) + + out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert dim_padded % k_block_num == 0 + k_block_size = dim_padded // k_block_num + grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) + + _splitK_reduce[grid]( + out_splitk, + metadata, + out, + lse, + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=seqlen_q_ceil, + BLOCK_SIZE=k_block_size, + G=n_group_q, + H=heads_per_group_q, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + IS_CAUSAL=causal, + num_warps=4) + + lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q]) + if q.ndim == 4: + # BMGHK -> BMHK + assert n_group_q == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if seqlen_k == 0: + out.zero_() + out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous() + + # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q + if original_layout == "bshd": + # 
out=out.transpose(1, 2).contiguous() # this screws up heads and data. + # the data is laid out properly. Just need to reshape dims + out = out.reshape(batch_size, seqlen_q, -1, dim_padded) + + return out.narrow(-1, 0, dim_k), lse diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py new file mode 100644 index 000000000..19ae4b139 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -0,0 +1,717 @@ +import torch +import triton +import triton.language as tl +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, arch_supports_fp8, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH +tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP + +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. +@triton.jit +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor + + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + descale_q, descale_k, descale_v, descale_p, IS_FP8: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr): + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + if PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. 
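+                # Illustrative example (hypothetical sizes): with actual_seqlen_k = 100 and BLOCK_N = 64,
+                # n_extra_tokens = 100 % 64 = 36, so the final k block holds only 36 valid columns.
+                # Only the iteration where start_n + BLOCK_N == block_max can overrun, and its
+                # out-of-range columns are filled with -inf below so they vanish in the softmax.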
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + + # compute masks + q_mask = (OFFS_M[:, None] < actual_seqlen_q) + k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + p_mask = q_mask & k_mask + + # -- compute qk ---- + if IS_FP8 : + qk += (tl.dot(q, k) * descale_q * descale_k) + else: + qk += tl.dot(q, k) + qk_scaled = qk * SM_SCALE + + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + qk_scaled += bias + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + qk_scaled += alibi_block + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) + + # scale and subtract max + q_shifted = qk_scaled - m_ij[:, None] + + # Compute scaled QK and softmax probabilities + if USE_EXP2: + p = tl.math.exp2(q_shifted * RCP_LN2) + else: + p = tl.math.exp(q_shifted) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + if tl_DROPOUT_USE_PYTORCH: + dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask) + else: + rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + dropout_mask = rng_output > dropout_p + if tl_DROPOUT_DUMP: + tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + + # return scores with negative values for dropped vals + sd_mask = tl.where(dropout_mask, p, -p) + tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that + tl.store(sd_mask_ptrs, p, mask=p_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff = m_i - m_ij + if USE_EXP2: + alpha = tl.math.exp2(m_diff * RCP_LN2) + else: + alpha = tl.math.exp(m_diff) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if IS_FP8: + p *= (1.0/ descale_p) # put p into fp8 range + acc += (tl.dot(p.to(v.type.element_ty), v) * descale_p * descale_v) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if RETURN_SCORES: + sd_mask_ptrs += BLOCK_N * stride_sn + + if ENABLE_DROPOUT: + dropout_mask_ptrs += BLOCK_N * stride_sn + philox_ptrs += BLOCK_N * stride_sn + return acc, l_i, m_i + + +def get_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + # Fall-back config. 
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_autotune_configs(): + if AUTOTUNE: + if is_rdna(): + return get_rdna_autotune_configs() + elif is_cdna(): + return get_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + else: + return [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "VARLEN", + "HQ", + "HK", + ] + + +autotune_configs, autotune_keys = get_autotune_configs() + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def attn_fwd(Q, K, V, bias, + DESCALE_Q, DESCALE_K, DESCALE_V, DESCALE_P, stride_q_inv_scale_z, stride_kv_inv_scale_z, stride_p_inv_scale_z, + SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, + stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + # print("cu_seqlens_q_start:", cu_seqlens_q_start) + + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = tl.cdiv(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. 
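+        # Illustrative example (hypothetical sizes): seqlen_q = 128, seqlen_k = 512,
+        # BLOCK_M = BLOCK_N = 64. The WG with start_m = 0 only ever needs
+        # cdiv((0 + 1) * 64 + 512 - 128, 64) = 7 of the 8 k blocks, so n_blocks is
+        # reduced from 8 to 7 by the min() below.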
+ # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = tl.cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + l_ptrs = l_offset + offs_m * stride_lse_m + + l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) + + # mask_m_offsets = start_m + tl.arange(0, BLOCK_M) + # lse_mask = mask_m_offsets < causal_start_idx + # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + n_extra_tokens = 0 + # print("n_extra_tokens:", n_extra_tokens) + # print("seqlen_k:", seqlen_k) + # print("BLOCK_N:", BLOCK_N) + # return + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. 
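+    # Each block of pointers below addresses one tile, e.g. the [BLOCK_M, BLOCK_DMODEL] Q tile:
+    # q_ptrs[i, j] = Q + off_z * stride_qz + off_h_q * stride_qh
+    #              + (cu_seqlens_q_start + offs_m[i]) * stride_qm + offs_d[j] * stride_qk.
+    # K is addressed transposed as a [BLOCK_DMODEL, BLOCK_N] tile so tl.dot(q, k) needs no transpose.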
+ q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if RETURN_SCORES: + sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + else: + sd_mask_ptrs = None + + if ENABLE_DROPOUT: + dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + dropout_mask_ptrs = dropout_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + else: + dropout_mask_ptrs = None + philox_ptrs = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + # Load scale factors if IS_FP8. + if IS_FP8: + descale_q = tl.load(DESCALE_Q + off_z * stride_q_inv_scale_z + off_h_q) + descale_k = tl.load(DESCALE_K + off_z * stride_kv_inv_scale_z + off_h_k) + descale_v = tl.load(DESCALE_V + off_z * stride_kv_inv_scale_z + off_h_k) + descale_p = tl.load(DESCALE_P + off_z * stride_p_inv_scale_z + off_h_q) + else: + descale_q, descale_k, descale_v, descale_p = 1.0, 1.0, 1.0, 1.0 + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. 
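+    # Illustrative split (hypothetical sizes): causal, seqlen_q = seqlen_k = 512,
+    # BLOCK_M = 128, BLOCK_N = 64, no padding -> masked_blocks = 128 // 64 = 2, so the
+    # last two k blocks of each row block take the masked path further down and every
+    # block before them runs through this unmasked fast path.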
+ if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, + descale_q, descale_k, descale_v, descale_p, IS_FP8, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if RETURN_SCORES: + sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn + if ENABLE_DROPOUT: + dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn + philox_ptrs += n_full_blocks * BLOCK_N * stride_sn + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, descale_p, IS_FP8, + IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: + dropout_scale = 1 / (1 - dropout_p) + acc = acc * dropout_scale + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. 
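+    # Illustrative case (hypothetical sizes): seqlen_q = 100, seqlen_k = 30, BLOCK_M = 64.
+    # causal_start_idx = 100 - 30 = 70 falls inside the M block covering rows 64..127, so
+    # rows 64..69 have every score at -inf, their softmax is NaN, and the masking below
+    # overwrites those rows of acc with zeros.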
+ end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + + # write back LSE(Log Sum Exponents), the log of the normalization constant + l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + l_ptrs = l_offset + offs_m * stride_lse_m + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + LN2: tl.constexpr = 0.6931471824645996 + # compute log-sum-exp in base 2 units + mi_base2 = m_i * RCP_LN2 + softmax_lse = mi_base2 + tl.math.log2(l_i) + # convert back to natural units + softmax_lse *= LN2 + else: + softmax_lse = m_i + tl.math.log(l_i) + + if IS_CAUSAL: + # zero out nans caused by -infs when doing causal + lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx + softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) + + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) # the log of the normalization constant + else: + tl.store(l_ptrs, softmax_lse) # the log of the normalization constant + + # write back O + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + +def attention_prefill_forward_triton_impl( + q, + k, + v, + o, + sm_scale, + alibi_slopes, + causal, + bias, + layout, + # varlen + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + # dropout + dropout_p, + philox_seed, + philox_offset, + # misc + return_softmax, + use_exp2, + # fp8 + descale_q=None, + descale_k=None, + descale_v=None, + descale_p=None): + + if DEBUG: + print() + print("attention_prefill_forward_triton_impl") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("o:", o, o.shape) + print("sm_scale:", sm_scale) + print("alibi_slopes:", alibi_slopes) + print("causal:", causal) + print("bias:", bias) + print("layout:", layout) + print("cu_seqlens_q:", cu_seqlens_q) + print("cu_seqlens_k:", cu_seqlens_k) + print("max_seqlens_q:", max_seqlens_q) + print("max_seqlens_k:", max_seqlens_k) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) + print("return_scores:", return_softmax) + print("use_exp2:", use_exp2) + + is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz} + if is_fp8: + if DEBUG: + print("IS_FP8") + + type_max = torch.finfo(q.dtype).max + if layout == "bshd": + batch, _ , nheads_q, dim = q.shape 
+ _, _ , nheads_k, _ = k.shape + elif layout == "bhsd": + batch, nheads_q,_, dim = q.shape + _, nheads_k, _, _ = k.shape + elif layout == "thd": + batch = len(cu_seqlens_q) - 1 + nheads_q = q.size(1) + nheads_k = k.size(1) + else: + raise ValueError("Unsupported layout") + + # Get strides for the kernel + descale_q_stride_z = descale_q.stride(0) + descale_k_stride_z = descale_k.stride(0) + descale_v_stride_z = descale_v.stride(0) + descale_p_stride_z = descale_p.stride(0) + else: + # For non-FP8 types, use dummy values (no scaling needed) + descale_q = descale_k = descale_v = descale_p = 1 + descale_q_stride_z = descale_k_stride_z = descale_v_stride_z = descale_p_stride_z = 0 + + + if DEBUG: + print("is_fp8:", is_fp8) + print("descale_q:", descale_q) + print("descale_k:", descale_k) + print("descale_v:", descale_v) + print("descale_p:", descale_p) + print("descale_q_stride_z:", descale_q_stride_z) + print("descale_k_stride_z:", descale_k_stride_z) + print("descale_v_stride_z:", descale_v_stride_z) + print("descale_p_stride_z:", descale_p_stride_z) + if is_fp8: + print(f"type_max: {type_max}") + + + # check if varlen + is_varlen = layout == "thd" + + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (bias is not None): + assert (bias.numel() < 2**31) + + batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + + grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + + # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. 
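+    # e.g. a kept probability p is stored as +p and a dropped one as -p in sd_mask (see the
+    # ENABLE_DROPOUT branch of _attn_fwd_inner), so the test harness can recover both the
+    # pre-dropout softmax and the dropout pattern from the sign alone.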
+ use_dropout = (dropout_p > 0.0) + if use_dropout or return_softmax: + sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + dtype=torch.float32) + if DROPOUT_USE_PYTORCH: + dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlens_q, max_seqlens_k), seed = philox_seed) + else: + dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + dtype=torch.float32) + scores_strides = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3)) + else: + sd_mask = None + dropout_mask = None + scores_strides = (0, 0, 0, 0) + + # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) + if is_varlen: + softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32) + stride_lse_m, stride_lse_h = softmax_lse.stride() + stride_lse_z = 0 + else: + softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + + if bias is not None: + bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2), + bias.stride(3)) + else: + bias_strides = (0, 0, 0, 0) + + if alibi_slopes is not None: + alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + + attn_fwd[grid](q, k, v, bias, + descale_q, descale_k, descale_v, descale_p, descale_q_stride_z, descale_k_stride_z, descale_p_stride_z, + sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, + HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, + BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, + USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8) + + if DEBUG: + print() + print("attention_prefill_forward_triton_impl outputs") + print("o:", o, o.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None) + if use_dropout: + print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None) + print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) + write_dropout_mask(dropout_mask, "dropout_mask_fwd") + + return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py new file mode 100644 index 000000000..909996654 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -0,0 +1,381 @@ +import torch +import math +from .utils import DEBUG + +DEBUG_CORE = False + +def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2): + if DEBUG_CORE: + print() + print("attention_forward_core_ref_impl") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("sm_scale:", sm_scale) + print("causal:", causal) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) + 
print("use_exp2:", use_exp2) + + # cast to float32 + q = q.to(torch.float32) + k = k.to(torch.float32) + v = v.to(torch.float32) + + # Compute attention scores + attention_scores = torch.matmul(q, k.transpose(-2, -1)) + if DEBUG_CORE: + print("attention_scores:", attention_scores, attention_scores.shape) + + # Scale scores + attention_scaled_scores = sm_scale * attention_scores + if DEBUG_CORE: + print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) + + # Apply causal mask if necessary + if causal: + L_q, L_k = q.shape[1], k.shape[1] + row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) + col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) + col_offset = L_q-L_k + causal_mask = row_idx >= (col_offset + col_idx) + if DEBUG_CORE: + print("causal_mask:", causal_mask) + # set -inf to places the causal mask is false + attention_scaled_scores = attention_scaled_scores.masked_fill( + torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') + ) + if DEBUG_CORE: + print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) + + # Compute max for numerical stability + max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0] + if DEBUG_CORE: + print("max_scores:", max_scores, max_scores.shape) + if causal: + # Replace -inf in max_scores with zeros to avoid NaN in subtraction + max_scores = torch.where( + torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores + ) + if DEBUG: + print("max_scores if causal:", max_scores, max_scores.shape) + + # Shift scores + attention_shifted_scaled_scores = attention_scaled_scores - max_scores + if DEBUG_CORE: + print("attention_shifted_scaled_scores:", attention_shifted_scaled_scores, attention_shifted_scaled_scores.shape) + + # Exponentiate + if use_exp2: + RCP_LN = 1 / math.log(2) + exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores) + else: + exp_scores = torch.exp(attention_shifted_scaled_scores) + + if DEBUG_CORE: + print("exp_scores:", exp_scores, exp_scores.shape) + + # Sum of exponentials + sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True) + if DEBUG_CORE: + print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) + if causal: + # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. 
Setting to 1 deals with -inf case cleanly + sum_exp_scores = torch.where( + sum_exp_scores == 0, + torch.ones_like(sum_exp_scores), + sum_exp_scores + ) + if DEBUG_CORE: + print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) + + # Compute softmax probabilities + p = exp_scores / sum_exp_scores + + if DEBUG_CORE: + print("softmax:", p, p.shape) + + # apply dropout if specified + if dropout_p > 0.0: + rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) + dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) + if DEBUG_CORE: + print("dropout_scale:", dropout_scale) + print("dropout_mask:", dropout_mask) + # Apply dropout mask and scale + # Set -1 for dropped positions and 1 for kept positions in exp_scores + sd_mask = torch.where(dropout_mask, exp_scores, -exp_scores) + p = torch.where(dropout_mask, p , torch.zeros_like(p)) * dropout_scale + if DEBUG_CORE: + print("softmax after dropout:", p) + print("sd_mask:", sd_mask) + else: + sd_mask = exp_scores + + # Compute log-sum-exp + if use_exp2: + LN2 = math.log(2) + RCP_LN = 1 / math.log(2) + max_scores_base2 = max_scores * RCP_LN + softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores) + softmax_lse = softmax_lse_base2 * LN2 + softmax_lse.squeeze_(-1) + else: + softmax_lse = max_scores + torch.log(sum_exp_scores) + softmax_lse = softmax_lse.squeeze(-1) + + if DEBUG_CORE: + print("softmax_lse:", softmax_lse, softmax_lse.shape) + + # Compute output + o = torch.matmul(p, v) + if DEBUG_CORE: + print("o:", o, o.shape) + + # cast back to original dtype + o = o.to(torch.float16) + # softmax_lse = softmax_lse.to(torch.float16) # NOTE: if you cast lse to fp16 it cause accuracy issues. keep fp32 + sd_mask = sd_mask.to(torch.float16) + + return o, softmax_lse, sd_mask + +def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2): + """Compute reference output and softmax_lse using PyTorch's built-in function""" + + # Ensure the layout is 'bhsd' + if layout == "bshd": + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + elif layout != "bhsd": + raise ValueError(f"Unknown layout {layout}") + + # Prepare tensors + batch_size, nheads_q, seq_len_q, head_dim = q.shape + batch_size, nheads_k, seq_len_k, head_dim = k.shape + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + if group_size != 1: + # MQA or GQA case + # Reshape q to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Expand k and v to match group_size + k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + # Flatten the first three dimensions for computation + q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + else: + q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) + + # Call the core attention function + o, softmax_lse, sd_mask = attention_forward_core_ref_impl( + q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, use_exp2 + ) 
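+    # e.g. (hypothetical sizes) nheads_q = 8, nheads_k = 2 gives group_size = 4: the core call
+    # above ran on a flat batch of batch_size * nheads_k * group_size = batch_size * 8 heads,
+    # and the reshapes below fold the (nheads_k, group_size) axes back into nheads_q.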
+ + if group_size != 1: + # Reshape outputs back to original dimensions + o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) + softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) + sd_mask = sd_mask.reshape(batch_size, nheads_k, group_size, seq_len_q, seq_len_k) + sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) + else: + # Standard case + o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) + sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) + + # Restore original layout if necessary + if layout == "bshd": + o = o.transpose(1, 2) + + return o, softmax_lse, sd_mask + + +def attention_varlen_forward_pytorch_ref_impl( + q, + k, + v, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + use_exp2 +): + # Ensure the layout is 'thd' + if layout != 'thd': + raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") + + batch_size = cu_seqlens_q.shape[0] - 1 + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_dim = q.shape[2] + + # Pre-allocate outputs + total_L_q = q.shape[0] + total_L_k = k.shape[0] + + o = torch.empty((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device) + softmax_lse = torch.empty((total_L_q, nheads_q), dtype=torch.float32, device=q.device) + sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device) + + # Compute group_size for MQA/GQA handling + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + for i in range(batch_size): + # Get the start and end indices for the current sequence + start_q = cu_seqlens_q[i].item() + end_q = cu_seqlens_q[i + 1].item() + start_k = cu_seqlens_k[i].item() + end_k = cu_seqlens_k[i + 1].item() + + seqlen_q = end_q - start_q + seqlen_k = end_k - start_k + + if DEBUG: + print(f"Batch {i} with seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, Hq= {nheads_q}, Hk = {nheads_k}") + + # Extract q_i, k_i, v_i + q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + + # Permute to [nheads, L_q_i, head_dim] + q_i = q_i.permute(1, 0, 2) + k_i = k_i.permute(1, 0, 2) + v_i = v_i.permute(1, 0, 2) + + # Handle MQA/GQA by adjusting shapes based on group_size + if group_size != 1: + # Reshape q_i to [nheads_k, group_size, L_q_i, head_dim] + q_i = q_i.reshape(nheads_k, group_size, seqlen_q, head_dim) + # Expand k_i and v_i to match group_size + k_i = k_i.unsqueeze(1).expand(-1, group_size, -1, -1) + v_i = v_i.unsqueeze(1).expand(-1, group_size, -1, -1) + # Flatten the first two dimensions for computation + q_i = q_i.reshape(nheads_k * group_size, seqlen_q, head_dim) + k_i = k_i.reshape(nheads_k * group_size, seqlen_k, head_dim) + v_i = v_i.reshape(nheads_k * group_size, seqlen_k, head_dim) + else: + # Standard case + q_i = q_i.reshape(nheads_q, seqlen_q, head_dim) + k_i = k_i.reshape(nheads_k, seqlen_k, head_dim) + v_i = v_i.reshape(nheads_k, seqlen_k, head_dim) + + # Call the core attention function for this sequence + o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, 
philox_seed, philox_offset, use_exp2) + + # Reshape outputs back to original dimensions + if group_size != 1: + # Reshape outputs to [nheads_k, group_size, seqlen_q, head_dim] + o_i = o_i.reshape(nheads_k, group_size, seqlen_q, head_dim) + # Combine the first two dimensions back to nheads_q + o_i = o_i.reshape(nheads_q, seqlen_q, head_dim) + # Reshape softmax_lse_i similarly + softmax_lse_i = softmax_lse_i.reshape(nheads_k, group_size, seqlen_q) + softmax_lse_i = softmax_lse_i.reshape(nheads_q, seqlen_q) + else: + # Outputs are already in the correct shape + pass + + # Convert back to 'thd' layout + o_i = o_i.permute(1, 0, 2) # [L_q_i, nheads_q, head_dim] + softmax_lse_i = softmax_lse_i.permute(1, 0) # [L_q_i, nheads_q] + sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i] + + # Place outputs in pre-allocated tensors + o[start_q:end_q, :, :] = o_i + softmax_lse[start_q:end_q, :] = softmax_lse_i + sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i + + return o, softmax_lse, sd_mask + + + +def attention_forward_pytorch_ref_impl( + q, + k, + v, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + use_exp2 + ): + if DEBUG: + print() + print("attention_forward_pytorch_ref_impl") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("sm_scale:", sm_scale) + print("causal:", causal) + print("layout:", layout) + print("cu_seqlens_q:", cu_seqlens_q) + print("cu_seqlens_k:", cu_seqlens_k) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) + print("use_exp2:", use_exp2) + + # compute reference + if layout == "thd": + o_ref, softmax_lse_ref, sd_mask_ref = attention_varlen_forward_pytorch_ref_impl( + q.clone(), + k.clone(), + v.clone(), + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + use_exp2, + ) + else: + o_ref, softmax_lse_ref, sd_mask_ref = attention_vanilla_forward_pytorch_ref_impl(q.clone(), + k.clone(), + v.clone(), + sm_scale, + causal, + layout, + dropout_p, + philox_seed, + philox_offset, + use_exp2) + + if DEBUG: + print() + print("attention_forward_pytorch_ref_impl outputs") + print("o:", o_ref, o_ref.shape) + print("softmax_lse:", softmax_lse_ref, softmax_lse_ref.shape) + print("sd_mask:", sd_mask_ref, sd_mask_ref.shape if sd_mask_ref is not None else None) + + return o_ref, softmax_lse_ref, sd_mask_ref diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py new file mode 100644 index 000000000..d4b736e48 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -0,0 +1,610 @@ +import torch +import os +from .fwd_prefill import attention_prefill_forward_triton_impl +from .bwd_prefill import attention_prefill_backward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_split_oneKernel import attention_prefill_backward_triton_split_oneKernel_impl +from .fwd_decode import attention_decode_forward_triton_impl +from .fwd_ref import attention_forward_pytorch_ref_impl +from .bwd_ref import attention_backward_pytorch_ref_impl +from .utils import MetaData, get_shape_from_layout, DEBUG, USE_SINGLE_BWD_KERNEL +from einops import rearrange, repeat +from flash_attn.layers.rotary import apply_rotary_emb + +USE_REF = 
os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') + +def fwd(q, + k, + v, + o, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + return_softmax, + gen_, + descale_q, + descale_k, + descale_v, + descale_p): + + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("o:", o) + print("alibi_slopes:", alibi_slopes) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("return_softmax:", return_softmax) + + + if o is None: + o = torch.empty_like(q) + + # Setup metadata + metadata = MetaData(sm_scale=softmax_scale) + metadata.max_seqlens_q = q.shape[1] + metadata.max_seqlens_k = k.shape[1] + metadata.layout = "bshd" + if return_softmax: + metadata.return_scores = True + + batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) + + if causal: + metadata.need_causal() + + if alibi_slopes is not None: + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + if dropout_p > 0.0: + metadata.need_dropout(dropout_p) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast + else: + rng_state = None + + # check arguments + metadata.check_args(q, k, v, o) + + # call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + output, softmax_lse, sd_mask = attention_forward_pytorch_ref_impl( + q, + k, + v, + metadata.sm_scale, + metadata.causal, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.use_exp2) + o.copy_(output) + else: + if DEBUG: + print("Using Triton implementation") + output, softmax_lse, sd_mask = attention_prefill_forward_triton_impl( + q, + k, + v, + o, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + metadata.bias, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_p) + + if DEBUG: + print("fwd outputs") + print("o:", o, o.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("exp_scores:", sd_mask, sd_mask.shape if sd_mask is not None else None ) + + return o, softmax_lse, sd_mask, rng_state + +def bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + deterministic, + gen_, + rng_state, +): + # NOTE: this might have perf costs + dq.zero_() + dk.zero_() + dv.zero_() + + if DEBUG: + print() + print("flash_attn_triton_amd.py::bwd") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("alibi_slopes:", alibi_slopes) + print("dropout_p:", dropout_p) + print("out:", out) + print("softmax_scale:", softmax_scale) + 
print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("gen_:", gen_) + print("rng_state:", rng_state) + + if dropout_p > 0.0: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + + # call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + + dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + causal, + "bshd", + None, + None, + None, + None, + dropout_p, + philox_seed, + philox_offset, + False, + ) + dq.copy_(dq_ref) + dk.copy_(dk_ref) + dv.copy_(dv_ref) + delta = delta_ref + else: + if DEBUG: + print("Using Triton implementation") + if USE_SINGLE_BWD_KERNEL: + bwd = attention_prefill_backward_triton_impl + else: + # bwd = attention_prefill_backward_triton_split_impl + bwd = attention_prefill_backward_triton_split_oneKernel_impl + _, _, _, delta_triton, _, _ = bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "bshd", + None, + None, + None, + None, + dropout_p, + philox_seed, + philox_offset, + True, + ) + delta = delta_triton + + if DEBUG: + print("bwd outputs") + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + return dq, dk, dv, delta + +def varlen_fwd( + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + block_table_, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + window_size_left, + window_size_right, + softcap, + return_softmax, + gen_, + descale_q, + descale_k, + descale_v, + descale_p): + + if DEBUG: + print() + print("flash_attn_triton_amd.py::varlen_fwd") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("gen_:", gen_) + + if o is None: + o = torch.empty_like(q) + + # Setup metadata + metadata = MetaData(sm_scale=softmax_scale) + if return_softmax: + metadata.return_scores = True + metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata + + # get shapes + batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + + if causal: + metadata.need_causal() + + if alibi_slopes is not None: + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + if dropout_p > 0.0: + metadata.need_dropout(dropout_p) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast + else: + rng_state = None + + # Check arguments + metadata.check_args(q, k, v, o) + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + # call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + output, softmax_lse, sd_mask = attention_forward_pytorch_ref_impl( + q, + k, + v, + metadata.sm_scale, + metadata.causal, + metadata.layout, + 
metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.use_exp2) + o.copy_(output) + else: + if DEBUG: + print("Using Triton implementation") + output, softmax_lse, sd_mask = attention_prefill_forward_triton_impl( + q, + k, + v, + o, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + metadata.bias, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_p) + if DEBUG: + print("varlen_fwd outputs") + print("o:", o, o.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) + + + return o, softmax_lse, sd_mask, rng_state + +def varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + window_size_left, + window_size_right, + softcap, + deterministic, + gen_, + rng_state, +): + if DEBUG: + print() + print("varlen_bwd") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("out:", out) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("gen_:", gen_) + print("rng_state:", rng_state) + + if dropout_p > 0.0: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + + # call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + True, + ) + dq.copy_(dq_ref) + dk.copy_(dk_ref) + dv.copy_(dv_ref) + delta = delta_ref + else: + if DEBUG: + print("Using Triton implementation") + if USE_SINGLE_BWD_KERNEL: + bwd = attention_prefill_backward_triton_impl + else: + # bwd = attention_prefill_backward_triton_split_impl + bwd = attention_prefill_backward_triton_split_oneKernel_impl + dq_triton, dk_triton, dv_triton, delta_triton, _, _ = bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + False, + ) + delta = delta_triton + + if DEBUG: + print("varlen_bwd outputs") + print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + + return dq, dk, dv, delta + +def fwd_kvcache( + q, + 
k_cache, + v_cache, + k, + v, + cache_seqlens, + rotary_cos, + rotary_sin, + cache_batch_idx, + cache_leftpad, + block_table, + alibi_slopes, + out, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + rotary_interleaved, + num_splits): + + if out is None: + out = torch.empty_like(q) + + # fill metadata + metadata = MetaData(sm_scale=softmax_scale) + metadata.layout = "bshd" + metadata.max_seqlens_q = q.shape[1] + metadata.max_seqlens_k = k_cache.shape[1] + metadata.cache_seqlens = cache_seqlens + metadata.cache_batch_idx = cache_batch_idx + + if k is not None and v is not None: + metadata.new_kv = True + metadata.seqlen_new = k.shape[1] + metadata.k_new = k + metadata.v_new = v + + if causal: + metadata.need_causal() + + if alibi_slopes is not None: + batch, _ , nheads_q, _= q.shape + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + # rotary boolean + apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin) + if apply_rotary: + metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) + + # Rotary Embedding Implementation + if apply_rotary: + if metadata.causal: # NOTE: when support is added. Add `or metadata.local` + q_ro = apply_rotary_emb( + q, + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=metadata.max_seqlens_q, + ) + k_ro = apply_rotary_emb( + metadata.k_new, + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ) + + q, metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) + + # launch kernel + # TODO: pass output as an arg. 
Maybe we are copying output which is causing slow down + output, softmax_lse = attention_decode_forward_triton_impl( + q, + k_cache, + v_cache, + metadata.sm_scale, + metadata.causal, + metadata.alibi_slopes, + metadata.layout, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.new_kv, + metadata.k_new, + metadata.v_new, + ) + return output, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/interface_torch.py b/flash_attn/flash_attn_triton_amd/interface_torch.py new file mode 100644 index 000000000..6c0f9d029 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_torch.py @@ -0,0 +1,100 @@ +import os +import torch +from .fwd_prefill import attention_prefill_forward_triton_impl +from .bwd_prefill import attention_prefill_backward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_split_oneKernel import attention_prefill_backward_triton_split_oneKernel_impl +from .fwd_decode import attention_decode_forward_triton_impl +from .utils import USE_SINGLE_BWD_KERNEL + + +class _attention_prefill(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, o, metadata): + output, softmax_lse, sd_mask = attention_prefill_forward_triton_impl( + q, + k, + v, + o, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + metadata.bias, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2) + + ctx.save_for_backward(q, k, v, o, softmax_lse) + ctx.sm_scale = metadata.sm_scale + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = metadata.philox_seed + ctx.philox_offset = metadata.philox_offset + ctx.return_scores = metadata.return_scores + ctx.layout = metadata.layout + ctx.use_exp2 = metadata.use_exp2 + return output, softmax_lse, sd_mask + + @staticmethod + def backward(ctx, do, *args): + q, k, v, o, softmax_lse = ctx.saved_tensors + if USE_SINGLE_BWD_KERNEL: + bwd = attention_prefill_backward_triton_impl + else: + bwd = attention_prefill_backward_triton_split_oneKernel_impl + # bwd = attention_prefill_backward_triton_split_impl + return bwd( + do, + q, + k, + v, + o, + softmax_lse, + None, + None, + None, + ctx.sm_scale, + ctx.alibi_slopes, + ctx.causal, + ctx.layout, + None, + None, + None, + None, + ctx.dropout_p, + ctx.philox_seed, + ctx.philox_offset, + ctx.use_exp2 + ) + +attention_prefill = _attention_prefill.apply + + +class _attention_decode(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, metadata): + output, softmax_lse = attention_decode_forward_triton_impl( + q, + k, + v, + metadata.sm_scale, + metadata.causal, + metadata.alibi_slopes, + metadata.layout, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.new_kv, + metadata.k_new, + metadata.v_new, + ) + return output, softmax_lse + +attention_decode = _attention_decode.apply diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py new file mode 100644 index 000000000..8527ab5ff --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -0,0 +1,1278 @@ +import torch +import pytest +import numpy as np +from flash_attn import flash_attn_func, flash_attn_varlen_func + +from .utils import DEBUG, DEBUG_TRITON, DEBUG_TRITON_DETAIL, MetaData, get_input_shapes, input_helper, varlen_input_helper, 
compute_alibi_tensor_ref, get_arch, arch_supports_fp8 +from .interface_torch import attention_prefill, attention_decode +from .fwd_ref import attention_forward_pytorch_ref_impl +from .fwd_prefill import attention_prefill_forward_triton_impl +from .bwd_prefill import attention_prefill_backward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_split_oneKernel import attention_prefill_backward_triton_split_oneKernel_impl +from .bwd_ref import attention_backward_pytorch_ref_impl +from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4 + +# set print options +torch.set_printoptions(linewidth=5e5, edgeitems=10, sci_mode=False) +np.set_printoptions(linewidth=5000, threshold=1e4, suppress=True, precision=4) + +# defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html +ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. +# ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues +# ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs +# ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs +# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # to strict for larger tensors in fp8 +ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # test pass with dropout and causal in fp8 +EQUAL_NAN = True + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +def test_op_fwd_prefill(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, 
N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 15498, 2, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + # TODO: This config fails. Disabled until triaged and fixed. + # (2, 16, 1020, 987, 128), + # (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_bias', [True]) +def test_op_fwd_prefill_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): + torch.manual_seed(20) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') + if causal: + input_metadata.need_causal() + if use_bias: + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") + input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) + else: + bias = None + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) + # reference implementation:171 + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_bias: + scores += input_metadata.bias + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
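+        # Tiny illustration of the note above: when causal masking leaves a row entirely -inf,
+        # softmax of that row is NaN, e.g.
+        #   torch.softmax(torch.full((4,), float("-inf")), dim=-1)  # -> tensor([nan, nan, nan, nan])
+        # hence the NaN -> 0 fix-up below.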
+ nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 8192, 64), + (4, 48, 256, 64), + (4, 48, 512, 64), + (4, 48, 1024, 64), + (8, 48, 4096, 64), + (4, 48, 8192, 64), + (4, 48, 128, 128), + (4, 48, 4096, 128), + (4, 48, 16384, 128), + (4, 16, 1024, 128), + (4, 16, 8192, 128), + (32, 48, 8192, 128) + ] + ) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) + + tri_out = torch.empty_like(q) + ref_out = torch.empty_like(q) + + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) + attention_prefill(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), + (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), + (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), + (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), + (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) + ref_out = torch.empty_like(q) + tri_out = torch.empty_like(q) + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the + # size aligns with Q. 
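+    # Shape walk-through (varlen "thd" layout): k is [total_k, HK, D]. The view/expand below turns it
+    # into [total_k, HK, HQ // HK, D], i.e. each KV head is repeated HQ // HK times, and the
+    # per-sequence reshape inside the loop flattens that to [seqlen_i, HQ, D] so it lines up with q.
+    # For example, with HQ=48 and HK=24 every KV head is shared by 2 query heads.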
+ k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_ref[start_k:end_k] + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + attention_prefill(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + # smallest config test + (1, 1, 16, 16, 64), # pass on new # fail on old + (1, 1, 32, 32, 64), # pass on new # fail on old + (1, 1, 64, 64, 16), # pass # smallest head_size = 16 + (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64 + (1, 1, 128, 128, 64), # pass + (1, 1, 256, 256, 64), # pass + (1, 1, 512, 512, 64), # pass + # failing FA + (1, 1, 256, 512, 16), + # old tests that work + (4, 48, 1024, 1024, 64), # pass + (4, 48, 2048, 2048, 64), # pass + (2, 48, 4096, 4096, 64), # pass + (1, 16, 1024, 1024, 64), # pass + (1, 16, 1024, 1024, 128), # pass + # old tests that were commented out + # (1, 16, 8192, 8192, 63), + # (1, 16, 1022, 1022, 64), +]) +# @pytest.mark.parametrize('torch_sdpa_test', [False, True]) +@pytest.mark.parametrize('torch_sdpa_test', [False]) +# @pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize('use_alibi', [False, True]) +@pytest.mark.parametrize('use_alibi', [False]) +def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): + torch.manual_seed(20) + + DEBUG_INPUT = False + + # seqlens + seqlen_q = N_CTX_Q + seqlen_k = N_CTX_K + + # setup up metadata + if DEBUG_INPUT: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = seqlen_q + input_metadata.max_seqlens_k = seqlen_k + input_metadata.layout = "bhsd" + + dropout_p = 0 + if DEBUG_INPUT: + q = torch.arange(seqlen_q, dtype=dtype, device="cuda").view(1, 1, seqlen_q, 1).expand(Z, H, seqlen_q, D_HEAD).requires_grad_() + k = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_() + v = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_() + o = torch.zeros_like(q) + else: + # Generate random inputs + q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) + o = torch.empty_like(q) + + if causal: + input_metadata.need_causal() + + if use_alibi and not torch_sdpa_test: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, 
H) + + if DEBUG_INPUT: + dout = torch.ones_like(q) + else: + dout = torch.randn_like(q) + + # reference implementation + if torch_sdpa_test: + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, + is_causal=causal, scale=sm_scale, + dropout_mask=None) + ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + else: + M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if use_alibi: + p += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, N_CTX_K) + if causal: + p[:, :, M == 0] = float("-inf") + + p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # # triton implementation + tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + if DEBUG: + print("tri_out:", tri_out) + print("ref_out:",ref_out ) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + + # The current block size for MI200 series is 64x64. This results in + # larger differences in float results due to rounding. + if dtype == torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + if dtype == torch.float32: + ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + else: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + + RTOL = 0 + + if DEBUG: + print("ref_dv:", ref_dv) + print("tri_dv:", tri_dv) + print("ref_dk:", ref_dk) + print("tri_dk:", tri_dk) + print("ref_dq:", ref_dq) + print("tri_dq:", tri_dq) + + torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 64, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (3, 2, 2, 256, 512, 16), + (3, 3, 3, 128, 128, 64), + (2, 4, 4, 1024, 1024, 64), + (4, 6, 6, 108, 256, 224), + (4, 8, 8, 2048, 2048, 128), + (4, 16, 16, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # fa configs + (4, 6, 1, 113, 203, 256), + (4, 6, 1, 128, 217, 256), + (4, 6, 2, 113, 211, 128), + (4, 6, 2, 108, 256, 128), + (4, 6, 1, 256, 512, 64), + (4, 6, 1, 512, 256, 64), + (4, 6, 2, 1024, 1024, 32), + (4, 6, 2, 1023, 1024, 32), + (4, 6, 6, 1024, 1023, 32), + (4, 6, 6, 2048, 2048, 32), + ], +) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"]) +@pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false +@pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. 
Just use to figure out issues +def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, use_exp2, DEBUG_INPUT): + dtype = torch.float16 + torch.manual_seed(0) + alibi_slopes = None + device = "cuda" + + if layout == "thd": + q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + else: + q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) + if DEBUG_INPUT: + output_triton = torch.zeros_like(q).contiguous() + else: + output_triton = torch.empty_like(q) + + if DEBUG: + if HQ // HK != 1: + print("MQA/GQA") + else: + print("MHA") + + # update metadata + metadata.use_exp2 = use_exp2 + if causal: + metadata.need_causal() + + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + if dropout_p > 0.0: + metadata.need_dropout(dropout_p) + + + # call Triton's forward implementation directly + output_triton, softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, + k, + v, + output_triton, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + metadata.bias, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2) + + output_ref, softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q.clone(), + k.clone(), + v.clone(), + metadata.sm_scale, + causal, + layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + use_exp2 + ) + + if DEBUG: + print() + print("Compare Triton Impl with refernce Pytorch Impl") + + # this can be set to true manually or when using dropout + if metadata.return_scores: + if DEBUG: + print("sd_mask_triton:", sd_mask_triton, sd_mask_triton.shape) + print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape) + torch.testing.assert_close(sd_mask_triton, sd_mask_ref, atol=ATOL, rtol=RTOL) + + if DEBUG: + print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape) + print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) + torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL) + + if DEBUG: + print("output_triton:", output_triton, output_triton.shape) + print("output_ref:", output_ref, output_ref.shape) + torch.testing.assert_close(output_triton, output_ref, atol=ATOL, rtol=RTOL) + +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 4, 4, 4), + (2, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 8, 1, 2, 4, 16), + (1, 16, 1, 2, 4, 16), + (1, 32, 1, 2, 4, 16), + (1, 64, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 4, 4, 16), + (2, 1, 1, 4, 4 , 16), + (4, 6, 6, 8, 8 , 16), + (1, 1, 1, 4, 4, 32), + (1, 1, 1, 16, 16, 16), + (1, 1, 1, 32, 32, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 128, 16), + (1, 1, 1, 64, 64, 32), + (1, 1, 1, 64, 128, 32), + (1, 1, 1, 128, 128, 64), + (1, 1, 1, 128, 256, 45), + (1, 1, 1, 113, 203, 192), + (1, 1, 1, 256, 256, 64), + (1, 1, 1, 256, 512, 16), + (1, 1, 1, 512, 512, 64), + (1, 1, 1, 1024, 1024, 64), + # fa configs + (2, 2, 2, 128, 128, 65), + (2, 2, 2, 128, 128, 224), 
+ (4, 6, 6, 108, 256, 224), + (1, 1, 1, 256, 512, 16), + # old tests that work + (4, 48, 6, 1024, 1024, 64), + (4, 48, 12, 1024, 1024, 64), + (4, 48, 24, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 73), + (4, 48, 48, 2048, 2048, 64), + (1, 24, 24, 4096, 4096, 64), + (1, 16, 16, 1024, 1024, 64), + (1, 16, 16, 1024, 1024, 128), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal +@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) +@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors +def test_op_prefill_bwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, use_exp2, layout, sequence_parallel, DEBUG_INPUT): + if get_arch() == "gfx90a": + if layout == "thd" and Z == 4 and HQ == 48 and HK == 48 and N_CTX_Q == 1024 and N_CTX_K == 1024: + pytest.skip("This config doesnot work on MI200 Devices but works on MI300.") + + dtype = torch.float16 + torch.manual_seed(20) # seed from test_op_bwd + + alibi_slopes = None + if layout == "thd": + q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, DEBUG_INPUT=DEBUG_INPUT) + else: + q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, DEBUG_INPUT=DEBUG_INPUT) + if DEBUG_INPUT: + do = torch.ones_like(q).contiguous() + else: + do = torch.randn_like(q) + + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + if dropout_p > 0.0: + metadata.need_dropout(dropout_p) + + # =============================================== Reference ============================================================== + q_ref = q.clone() + k_ref = k.clone() + v_ref = v.clone() + output_ref, softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q_ref, + k_ref, + v_ref, + metadata.sm_scale, + causal, + layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + use_exp2 + ) + + + if DEBUG: + if HQ // HK != 1: + print("MQA/GQA") + else: + print("MHA") + + dq = torch.zeros_like(q, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros + if DEBUG_INPUT: + dk = torch.zeros_like(k, dtype=k.dtype) + dv = torch.zeros_like(v, dtype=v.dtype) + else: + dk = torch.empty_like(k, dtype=k.dtype) + dv = torch.empty_like(v, dtype=v.dtype) + + do_ref = do.clone() + dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + do_ref, + q_ref, + k_ref, + v_ref, + output_ref, + softmax_lse_ref, + metadata.sm_scale, + causal, + layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + use_exp2 + ) + + # =============================================== Triton ============================================================== + o = output_ref.clone().contiguous() + softmax_lse = softmax_lse_ref.clone().contiguous() + dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + metadata.sm_scale, + alibi_slopes, + causal, + layout, + metadata.cu_seqlens_q, + 
metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + use_exp2, + sequence_parallel=sequence_parallel + ) + + # =============================================== Check ============================================================== + if DEBUG: + print() + if DEBUG: + print("delta_triton:", delta_triton, delta_triton.shape) + print("delta_ref:", delta_ref, delta_ref.shape) + torch.testing.assert_close(delta_triton, delta_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) + + if DEBUG: + print("dv_triton:", dv_triton, dv_triton.shape) + print("dv_ref:", dv_ref, dv_ref.shape) + torch.testing.assert_close(dv_triton, dv_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) + + if DEBUG: + print("dk_triton:", dk_triton, dk_triton.shape) + print("dk_ref:", dk_ref, dk_ref.shape) + torch.testing.assert_close(dk_triton, dk_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) + + if DEBUG: + print("dq_triton:", dq_triton, dq_triton.shape) + print("dq_ref:", dq_ref, dq_ref.shape) + torch.testing.assert_close(dq_triton, dq_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) + + +@pytest.mark.parametrize('batch_size, seqlen_q, seqlen_k, group_q, group_k, dim', get_input_shapes()) +def test_op_fwd_decode(batch_size, seqlen_q, seqlen_k, group_q, group_k, dim, dtype=torch.bfloat16): + if get_arch() == "gfx90a": + if batch_size == 1 and seqlen_q == 1 and seqlen_k >= 65536: + pytest.skip("This config doesnot work on MI200 Devices but works on MI300.") + + torch.manual_seed(20) + query_group_head_size = (group_q + group_k - 1) // group_k + q = (torch.empty((batch_size, seqlen_q, group_k, query_group_head_size, dim), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) + v = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) + scale = 1 / dim**0.5 + input_metadata = MetaData(sm_scale=scale) + input_metadata.layout = "bsghd" + tri_out, _ = attention_decode(q, k, v, input_metadata) + + q = q.reshape([batch_size, seqlen_q, -1, dim]).permute(0, 2, 1, 3) + k = k.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) + v = v.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0) + +def test_quantization(): + a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') + qa = quantize_kv_int4(a, num_groups=4) + dqa = dequantize_kv_fp16(qa, num_groups=4) + torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) +def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + pytest.skip("Decode kernel doesnot support quantization yet") + torch.manual_seed(2) + q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, + device="cuda").normal_(mean=1.0, std=0.5).requires_grad_()) + k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + 
std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + + num_groups = 1 + quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32)) + quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32)) + scale = 1 / K**0.5 + input_metadata = MetaData(sm_scale=scale) + input_metadata.layout = "bsghd" + tri_out, _ = attention_decode(q, quant_k, quant_v, input_metadata) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) + + # since quantization introduces rounding error, use the + # dequantized kv as inputs to the ref implementation to reduce + # the tolerance to 1e-3 + dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) + dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) + dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) + dq_ref_out = dq_attn @ dqv + torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + + +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 64, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (3, 2, 2, 256, 512, 16), + (3, 3, 3, 128, 128, 64), + (2, 4, 4, 1024, 1024, 64), + (4, 6, 6, 108, 256, 224), + (4, 8, 8, 2048, 2048, 128), + (4, 16, 16, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # fa configs + (4, 6, 1, 113, 203, 256), + (4, 6, 1, 128, 217, 256), + (4, 6, 2, 113, 211, 128), + (4, 6, 2, 108, 256, 128), + (4, 6, 1, 256, 512, 64), + (4, 6, 1, 512, 256, 64), + (4, 6, 2, 1024, 1024, 32), + (4, 6, 2, 1023, 1024, 32), + (4, 6, 6, 1024, 1023, 32), + (4, 6, 6, 2048, 2048, 32), + ], +) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.0, 0.25]) +@pytest.mark.parametrize('DEBUG_INPUT', [False]) +@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") +def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT): + device = "cuda" + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + layout = "bshd" + + q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) + + # NOTE: use bfp16 becasue it fp32 trunacted + # launch kernel in fp16 + q_bfp16 = q.clone().to(torch.bfloat16) + k_bfp16 = k.clone().to(torch.bfloat16) + v_bfp16 = v.clone().to(torch.bfloat16) + out_bfp16, lse_bfp16, S_dmask_bfp16 = flash_attn_func( + q_bfp16, + k_bfp16, + v_bfp16, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if DEBUG: + print("out_bfp16", out_bfp16) + print("lse_bfp16", lse_bfp16) + print("S_dmask_bfp16", S_dmask_bfp16) + + # compute p for descaling + batch, _ , nheads_q, dim = q.shape + _, _ , nheads_k, _ = k.shape + + # compute max for each batch-head pair across seqlen and dim + q_max = torch.maximum(q.abs().amax(dim=(1, 
3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1) + k_max = torch.maximum(k.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1) + v_max = torch.maximum(v.abs().amax(dim=(1, 3)), torch.tensor(1e-9)).unsqueeze(1).unsqueeze(-1) + + # scale values to fp8 range + type_max = torch.finfo(torch.float8_e4m3fnuz).max + q_fp8 = (q * type_max/ q_max).to(torch.float8_e4m3fnuz) + k_fp8 = (k * type_max/ k_max).to(torch.float8_e4m3fnuz) + v_fp8 = (v * type_max/ v_max).to(torch.float8_e4m3fnuz) + + # compute descale values + descale_q = q_max / type_max + descale_k = k_max / type_max + descale_v = v_max / type_max + descale_p = torch.full_like(descale_q, 1.0 / type_max, dtype=torch.float32, device=q.device) + + # launch kernel in fp8 + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_func( + q_fp8, + k_fp8, + v_fp8, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_p=descale_p, + ) + if DEBUG: + print("out_fp8", out_fp8) + print("lse_fp8", lse_fp8) + print("S_dmask_fp8", S_dmask_fp8) + + if DEBUG: + print("out_bfp16:", out_bfp16, out_bfp16.shape) + print("out_fp8:", out_fp8, out_fp8.shape) + + torch.testing.assert_close(out_bfp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8) + +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 64, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (3, 2, 2, 256, 512, 16), + (3, 3, 3, 128, 128, 64), + (2, 4, 4, 1024, 1024, 64), + (4, 6, 6, 108, 256, 224), + (4, 8, 8, 2048, 2048, 128), + (4, 16, 16, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # fa configs + (4, 6, 1, 113, 203, 256), + (4, 6, 1, 128, 217, 256), + (4, 6, 2, 113, 211, 128), + (4, 6, 2, 108, 256, 128), + (4, 6, 1, 256, 512, 64), + (4, 6, 1, 512, 256, 64), + (4, 6, 2, 1024, 1024, 32), + (4, 6, 2, 1023, 1024, 32), + (4, 6, 6, 1024, 1023, 32), + (4, 6, 6, 2048, 2048, 32), + ], +) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.0, 0.25]) +@pytest.mark.parametrize('DEBUG_INPUT', [False]) +@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") +def test_op_prefill_varlen_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT): + device = "cuda" + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + layout = "thd" + + q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, torch.float32, DEBUG_INPUT=DEBUG_INPUT) + + # launch kernel in fp16 + q_bfp16 = q.clone().to(torch.bfloat16) + k_bfp16 = k.clone().to(torch.bfloat16) + v_bfp16 = v.clone().to(torch.bfloat16) + out_bfp16, lse_bfp16, S_dmask_bfp16 = flash_attn_varlen_func( + q_bfp16, + k_bfp16, + v_bfp16, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if DEBUG: + print("out_bfp16", out_bfp16) + print("lse_bfp16", lse_bfp16) + print("S_dmask_bfp16", S_dmask_bfp16) + + + if DEBUG: + print("q:", q, 
q.shape) + print("k:", k, k.shape) + + # thd + batch = len(metadata.cu_seqlens_q) - 1 + nheads_q = q.size(1) + nheads_k = k.size(1) + + if DEBUG: + print("batch:", batch) + print("nheads_q:", nheads_q) + print("nheads_k:", nheads_k) + + q_maxes = [] + k_maxes = [] + v_maxes = [] + for i in range(batch): + q_start = metadata.cu_seqlens_q[i] + q_end = metadata.cu_seqlens_q[i + 1] + k_start = metadata.cu_seqlens_k[i] + k_end = metadata.cu_seqlens_k[i + 1] + + # compute max for each batch-head pair across seqlen and dim + q_max = torch.maximum(q[q_start:q_end].abs().amax(dim=(0,2)), torch.tensor(1e-9)).unsqueeze(-1) + k_max = torch.maximum(k[k_start:k_end].abs().amax(dim=(0,2)), torch.tensor(1e-9)).unsqueeze(-1) + v_max = torch.maximum(v[k_start:k_end].abs().amax(dim=(0,2)), torch.tensor(1e-9)).unsqueeze(-1) + + q_maxes.append(q_max) + k_maxes.append(k_max) + v_maxes.append(v_max) + q_maxes = torch.stack(q_maxes) + k_maxes = torch.stack(k_maxes) + v_maxes = torch.stack(v_maxes) + if DEBUG: + print("q", q, q.shape) + print("q_maxes:", q_maxes, q_maxes.shape) + print("k", k, k.shape) + print("k_maxes:", k_maxes, k_maxes.shape) + + # ---------------------------------------------------------------- + # --- FP8 conversion part --- + # ---------------------------------------------------------------- + type_max = torch.finfo(torch.float8_e4m3fnuz).max + q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fnuz) + k_fp8 = torch.empty_like(k, dtype=torch.float8_e4m3fnuz) + v_fp8 = torch.empty_like(v, dtype=torch.float8_e4m3fnuz) + for i in range(batch): + q_start = metadata.cu_seqlens_q[i] + q_end = metadata.cu_seqlens_q[i + 1] + k_start = metadata.cu_seqlens_k[i] + k_end = metadata.cu_seqlens_k[i + 1] + + # shape [heads_q, 1], broadcast to [1, heads_q, 1] + q_scale = (type_max / q_maxes[i]).unsqueeze(0) # => [1, HQ, 1] + k_scale = (type_max / k_maxes[i]).unsqueeze(0) # => [1, HK, 1] + v_scale = (type_max / v_maxes[i]).unsqueeze(0) # => [1, HK, 1] + + # q, k, v are [L, heads, dim] slices + q_slice = q[q_start:q_end] # [seq_len_i, HQ, dim] + k_slice = k[k_start:k_end] # [seq_len_i, HK, dim] + v_slice = v[k_start:k_end] # [seq_len_i, HK, dim] + + # Convert them to FP8 + q_fp8[q_start:q_end] = (q_slice * q_scale).to(torch.float8_e4m3fnuz) + k_fp8[k_start:k_end] = (k_slice * k_scale).to(torch.float8_e4m3fnuz) + v_fp8[k_start:k_end] = (v_slice * v_scale).to(torch.float8_e4m3fnuz) + + if DEBUG: + print("q_fp8:", q_fp8, q_fp8.shape) + print("k_fp8:", k_fp8, k_fp8.shape) + + # compute descale values + descale_q = q_maxes / type_max + descale_k = k_maxes / type_max + descale_v = v_maxes / type_max + descale_p = torch.full_like(descale_q, 1.0 / type_max, dtype=torch.float32, device=q.device) + + # launch kernel in fp8 + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_func( + q_fp8, + k_fp8, + v_fp8, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_p=descale_p, + ) + if DEBUG: + print("out_fp8", out_fp8) + print("lse_fp8", lse_fp8) + print("S_dmask_fp8", S_dmask_fp8) + + if DEBUG: + print("out_bfp16:", out_bfp16, out_bfp16.shape) + print("out_fp8:", out_fp8, out_fp8.shape) + + torch.testing.assert_close(out_bfp16.to(torch.float32), out_fp8.to(torch.float32), atol=ATOL_fp8, rtol=RTOL_fp8) + + +@pytest.mark.parametrize( + "Z, 
HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 4, 4, 4), + (2, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 8, 1, 2, 4, 16), + (1, 16, 1, 2, 4, 16), + (1, 32, 1, 2, 4, 16), + (1, 64, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 4, 4, 16), + (2, 1, 1, 4, 4 , 16), + (4, 6, 6, 8, 8 , 16), + (1, 1, 1, 4, 4, 32), + (1, 1, 1, 16, 16, 16), + (1, 1, 1, 32, 32, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 128, 16), + (1, 1, 1, 64, 64, 32), + (1, 1, 1, 64, 128, 32), + (1, 1, 1, 128, 128, 64), + (1, 1, 1, 128, 256, 45), + (1, 1, 1, 113, 203, 192), + (1, 1, 1, 256, 256, 64), + (1, 1, 1, 256, 512, 16), + (1, 1, 1, 512, 512, 64), + (1, 1, 1, 1024, 1024, 64), + # fa configs + (2, 2, 2, 128, 128, 65), + (2, 2, 2, 128, 128, 224), + (4, 6, 6, 108, 256, 224), + (1, 1, 1, 256, 512, 16), + # old tests that work + (4, 48, 6, 1024, 1024, 64), + (4, 48, 12, 2048, 1024, 64), + (4, 48, 24, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 73), + (4, 48, 48, 2048, 2048, 64), + (1, 24, 24, 4096, 4096, 64), + (1, 16, 16, 1024, 1024, 64), + (1, 16, 16, 1024, 1024, 128), + # testcase new + # seqlen q == k + (1, 1, 1, 2, 2, 2), # small enough to debug + (1, 1, 1, 128, 128, 32), # only one block + (1, 1, 1, 127, 127, 32), # only one block but with masking + (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 1), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 68), # generic masking on q, k and head + (4, 1, 1, 512, 512, 128), # batch > 1 + (4, 8, 2, 512, 512, 128), # GQA + (4, 8, 2, 512, 512, 68), # non-power-of-2 head_dim + (4, 8, 2, 500, 500, 68), # comprehensive case for seqlen q == k + # seqlen q > k + (1, 1, 1, 64, 32, 8), # seqlen_q > seqlen_k + (1, 1, 1, 192, 128, 32), # seqlen_q > seqlen_k + (4, 8, 2, 1024, 512, 68), # seqlen_q < seqlen_k + (1, 1, 1, 729, 516, 68), # seqlen_q > seqlen_k + (16, 16, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k + # seqlen q < k + (1, 1, 1, 32, 64, 8), # seqlen_q > seqlen_k + (1, 1, 1, 128, 192, 32), # seqlen_q < seqlen_k + (4, 8, 2, 512, 1024, 68), # seqlen_q < seqlen_k + (1, 1, 1, 200, 413, 1), # seqlen_q < seqlen_k + (1, 1, 1, 782, 1546, 1), # seqlen_q < seqlen_k + (16, 16, 4, 1528, 2753, 68), # a comprehensive seqlen_q < seqlen_k + +# varlen +# dropout +# direct comparison among tutorial, Michael's implementation bwd and this one + +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.0, 0.2]) +@pytest.mark.parametrize('use_exp2', [True, False]) # FIXME: using exp2 causes issue when used with causal +# @pytest.mark.parametrize('layout', ["bhsd"]) +@pytest.mark.parametrize('layout', ["bhsd", "thd"]) +@pytest.mark.parametrize('sequence_parallel', [True]) +@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors +def test_op_prefill_bwd_split_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, use_exp2, layout, sequence_parallel, DEBUG_INPUT): + dtype = torch.float16 + torch.manual_seed(20) # seed from test_op_bwd + + alibi_slopes = None + if layout == "thd": + q, k, v, metadata = varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, DEBUG_INPUT=DEBUG_INPUT) + else: + q, k, v, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, DEBUG_INPUT=DEBUG_INPUT) + if DEBUG_INPUT: + do = torch.ones_like(q).contiguous() + else: + do = torch.randn_like(q) + + # NOTE: the 
returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + if dropout_p > 0.0: + metadata.need_dropout(dropout_p) + + # print("from the very beginning") + # print("q:", q.shape) + # print("k:", k.shape) + # print("v:", v.shape) + + # =============================================== Reference ============================================================== + q_ref = q.clone() + k_ref = k.clone() + v_ref = v.clone() + output_ref, softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q_ref, + k_ref, + v_ref, + metadata.sm_scale, + causal, + layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + use_exp2 + ) + + + if DEBUG: + if HQ // HK != 1: + print("MQA/GQA") + else: + print("MHA") + + dq = torch.zeros_like(q, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros + if DEBUG_INPUT: + dk = torch.zeros_like(k, dtype=k.dtype) + dv = torch.zeros_like(v, dtype=v.dtype) + else: + dk = torch.empty_like(k, dtype=k.dtype) + dv = torch.empty_like(v, dtype=v.dtype) + + do_ref = do.clone() + dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + do_ref, + q_ref, + k_ref, + v_ref, + output_ref, + softmax_lse_ref, + metadata.sm_scale, + causal, + layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + use_exp2 + ) + # =============================================== Triton ============================================================== + o = output_ref.clone().contiguous() + softmax_lse = softmax_lse_ref.clone().contiguous() + # dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_split_impl( + dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_split_oneKernel_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + metadata.sm_scale, + alibi_slopes, + causal, + layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + use_exp2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # =============================================== Check ============================================================== + if DEBUG: + print() + if DEBUG: + print("delta_triton:", delta_triton, delta_triton.shape) + print("delta_ref:", delta_ref, delta_ref.shape) + if DEBUG: + dim_names = ["batch", "qhead", "seqlen_kv", "head_dim"] + mismatch = torch.where(torch.isclose(dv_triton, dv_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) != 1) + num_error_dv = mismatch[0].numel() + if num_error_dv > 0: + print(f"\nnumber of mismatch in dv: {num_error_dv}") + for m, name in zip(mismatch, dim_names): + print(f"{name}: {m.unique().cpu()}") + dim_names = ["batch", "kvhead", "seqlen_kv", "head_dim"] + mismatch = torch.where(torch.isclose(dk_triton, dk_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) != 1) + num_error_dk = mismatch[0].numel() + if num_error_dk > 0: + print(f"\nnumber of mismatch in dk: {num_error_dk}") + for m, name in zip(mismatch, dim_names): + print(f"{name}: {m.unique().cpu()}") + dim_names = ["batch", "qhead", "seqlen_q", "head_dim"] + mismatch = torch.where(torch.isclose(dq_triton, dq_ref, atol=ATOL, 
rtol=RTOL, equal_nan=EQUAL_NAN) != 1) + num_error_dq = mismatch[0].numel() + if num_error_dq > 0: + print(f"\nnumber of mismatch in dq: {num_error_dq}") + for m, name in zip(mismatch, dim_names): + print(f"{name}: {m.unique().cpu()}") + + torch.testing.assert_close(delta_triton, delta_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) + torch.testing.assert_close(dv_triton, dv_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) + torch.testing.assert_close(dk_triton, dk_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) + torch.testing.assert_close(dq_triton, dq_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py new file mode 100644 index 000000000..d55dc5bad --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -0,0 +1,402 @@ + +import csv +import math +import torch +import os +import random +import functools +import triton +import triton.language as tl + +AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') +DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') +PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') +USE_SINGLE_BWD_KERNEL = os.environ.get('USE_SINGLE_BWD_KERNEL', '0').lower() in ('1', 'true', 'yes') +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +DEBUG_TRITON = os.environ.get('DEBUG_TRITON', '0').lower() in ('1', 'true', 'yes') and os.environ.get('TRITON_INTERPRET', '0').lower() in ('1', 'true', 'yes') +DEBUG_TRITON_DETAIL = os.environ.get('DEBUG_TRITON_DETAIL', '0').lower() in ('1', 'true', 'yes') +if USE_TRITON_ROCM: # TODO remove this + random.seed(42) +DROPOUT_USE_PYTORCH = False +DROPOUT_DUMP = False + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + layout = None + cache_seqlens = None + cache_batch_idx = None + new_kv = False + seqlen_new = None + k_new = None + v_new = None + return_scores= False + dropout_p= 0.0 + philox_seed, philox_offset = None, None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. + # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. 
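+    # i.e. exp(x * sm_scale) == exp2(x * sm_scale * log2(e)), so the kernels can fold log2(e) into
+    # the scale once and evaluate exp2 per element when use_exp2 is set.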
+ use_exp2 = False + rotary_sin = None + rotary_cos = None + rotary_interleaved = False + rotary_conjunction = False + + + def __repr__(self) -> str: + return (f"MetaData(\n" + f" sm_scale={self.sm_scale},\n" + f" cu_seqlens_q={self.cu_seqlens_q},\n" + f" cu_seqlens_k={self.cu_seqlens_k},\n" + f" max_seqlens_q={self.max_seqlens_q},\n" + f" max_seqlens_k={self.max_seqlens_k},\n" + f" bias={self.bias},\n" + f" alibi_slopes={self.alibi_slopes},\n" + f" causal={self.causal},\n" + f" num_contexts={self.num_contexts},\n" + f" varlen={self.varlen},\n" + f" layout={self.layout},\n" + f" cache_seqlens={self.cache_seqlens},\n" + f" cache_batch_idx={self.cache_batch_idx},\n" + f" new_kv={self.new_kv},\n" + f" seqlen_new={self.seqlen_new},\n" + f" k_new={self.k_new},\n" + f" v_new={self.v_new},\n" + f" dropout_p={self.dropout_p},\n" + f" return_scores={self.return_scores}\n" + f" use_exp2={self.use_exp2}\n" + f")") + + def __init__(self, sm_scale=1.0): + self.sm_scale = sm_scale + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. + assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): + self.rotary_sin = sin + self.rotary_cos = cos + self.rotary_interleaved = rotary_interleaved + self.rotary_conjunction = rotary_conjunction + + def need_dropout(self, dropout_p): + self.dropout_p = dropout_p + self.return_scores = True + self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) + if self.varlen: + assert q.dim() == 3 + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + # assert not self.return_scores + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen + +def input_helper(Z, HQ, HK, 
N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False): + torch.manual_seed(20) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + else: + assert False, f'Got unsupported tensor layout: {layout}' + + if DEBUG_INPUT: + if layout == "bhsd": + q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + elif layout == "bshd": + q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + else: + q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True) + k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) + v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) + + if DEBUG_INPUT: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout + return q, k, v, input_metadata + + +def random_seqlens_composition(N, Z): + # generate a random composition of N into Z positive parts. 
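+    # Method: draw Z - 1 distinct cut points from {1, ..., N - 1}, sort them, and take
+    # consecutive differences, so all Z parts are positive and sum to N.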
+ idx = torch.randperm(N - 1)[: Z - 1] + 1 + idx, _ = torch.sort(idx) + breakpoints = torch.cat([ + torch.tensor([0], dtype=torch.long), + idx, + torch.tensor([N], dtype=torch.long), + ]) + seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) + return seqlens + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False): + torch.manual_seed(20) + + # Random or equal sequence lengths based on 'equal_seqlens' flag + if not equal_seqlens: + seqlens_q = random_seqlens_composition(N_CTX_Q, Z) + seqlens_k = random_seqlens_composition(N_CTX_K, Z) + else: + seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32) + seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32) + + # calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)]) + cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32) + cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32) + + # total lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + + if DEBUG_INPUT: + sm_scale = 1.0 + + q = torch.empty(total_q, HQ, D_HEAD, dtype=dtype, device=device) + k = torch.empty(total_k, HK, D_HEAD, dtype=dtype, device=device) + v = torch.empty(total_k, HK, D_HEAD, dtype=dtype, device=device) + for i in range(Z): + q_start = cu_seqlens_q[i].item() + q_end = cu_seqlens_q[i+1].item() + q_length = q_end - q_start + k_start = cu_seqlens_k[i].item() + k_end = cu_seqlens_k[i+1].item() + k_length = k_end - k_start + + + q[q_start:q_end, :, :] = ( + torch.arange(q_length, dtype=dtype, device=device) + .view(q_length, 1, 1) + .expand(q_length, HQ, D_HEAD) + ) + k[k_start:k_end, :, :] = ( + torch.arange(k_length, dtype=dtype, device=device) + .view(k_length, 1, 1) + .expand(k_length, HK, D_HEAD) + ) + v[k_start:k_end, :, :] = ( + torch.arange(k_length, dtype=dtype, device=device) + .view(k_length, 1, 1) + .expand(k_length, HK, D_HEAD) + ) + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + + else: + # Initialize q, k, v with random values + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() + sm_scale = D_HEAD ** -0.5 + + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + + return q, k, v, input_metadata + + +def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): + if layout == 'bhsd': + batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape + batch_k, nheads_k, max_seqlen_k, head_size_k = k.shape + elif layout == 'bshd': + batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape + batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape + elif layout == 'thd': + batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] + batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] + else: + assert False, "Got unsupported layout." 
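+    # Note: for the varlen 'thd' layout the batch size comes from cu_seqlens_* and the caller
+    # must pass max_seqlen_q / max_seqlen_k explicitly, since they cannot be read from q/k shapes.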
+ + # assert + assert batch_q == batch_k + assert head_size_q == head_size_k + + return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k + +def get_strides_from_layout(q, k, v, o, layout): + if layout == 'thd': + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif layout == 'bhsd': + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif layout == 'bshd': + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' + return q_strides, k_strides, v_strides, o_strides + +def get_padded_headsize(size): + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + return padded_d_model + +def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + +def create_dropout_mask(dropout_p, shape, seed): + device = "cuda" + rand_vals = torch.rand(shape, generator=torch.Generator(device=device).manual_seed(seed), device=device, dtype=torch.float32) + return rand_vals > dropout_p + +def create_dropout_mask_varlen(dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed): + device = "cuda" + qlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) + klens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) + max_qlen = qlens.max() + max_klen = klens.max() + dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device) + for b in range(batch): + qlen = qlens[b] + klen = klens[b] + rand_vals = torch.rand((nheads_q, qlen, klen), generator=torch.Generator(device=device).manual_seed(philox_seed), device=device, dtype=torch.float32) + submask = rand_vals > dropout_p + dropout_mask[b, :, :qlen, :klen] = submask + + return dropout_mask + +def write_dropout_mask(x, tensor_name = "tensor"): + batch, head, seqlen_m, seqlen_n = x.shape + x = x.tolist() + + with open(f'{tensor_name}.csv', 'w') as f: + writer = csv.writer(f) + for b in range(batch): + for h in range(head): + dropout_mask = x[b][h] + if True: + BLOCK_M = 64 + BLOCK_N = 64 + + # Calculate number of blocks in each dimension + m_blocks = math.ceil(seqlen_m / BLOCK_M) + n_blocks = math.ceil(seqlen_n / BLOCK_N) + + # Process each block + for m_block in range(m_blocks): + # Calculate row range for current block + row_start = m_block * BLOCK_M + row_end = min(row_start + BLOCK_M, seqlen_m) + + for n_block in range(n_blocks): + # Calculate column range for current block + col_start = n_block * BLOCK_N + 
col_end = min(col_start + BLOCK_N, seqlen_n) + + # Extract and write the current block + for row_idx in range(row_start, row_end): + row_data = dropout_mask[row_idx][col_start:col_end] + writer.writerow(row_data) + else: + writer.writerows(dropout_mask) + +def _strides(x: torch.Tensor, *stride_names: str): + if x is None: + return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} + + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + +def get_input_shapes(): + cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)] + return cases + +@functools.cache +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + +@functools.cache +def get_arch(): + return triton.runtime.driver.active.get_current_target().arch + +@functools.cache +def is_cdna(): + return is_hip() and get_arch() in ('gfx908', 'gfx90a', 'gfx940', 'gfx941', 'gfx942') + +@functools.cache +def is_rdna(): + return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") + +@functools.cache +def arch_supports_fp8(): + return is_hip() and get_arch() in ('gfx942',) \ No newline at end of file diff --git a/setup.py b/setup.py index 3184c91dd..eb446905e 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" - +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" def get_platform(): """ @@ -139,7 +139,8 @@ def validate_and_update_archs(archs): # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source.
if IS_ROCM: - subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"]) + if not USE_TRITON_ROCM: + subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"]) else: subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) @@ -304,108 +305,112 @@ def validate_and_update_archs(archs): ) ) elif not SKIP_CUDA_BUILD and IS_ROCM: - ck_dir = "csrc/composable_kernel" - - #use codegen get code dispatch - if not os.path.exists("./build"): - os.makedirs("build") - - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2") - print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h - # See https://github.com/pytorch/pytorch/pull/70650 - generator_flag = [] - torch_dir = torch.__path__[0] - if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - - check_if_rocm_home_none("flash_attn") - archs = os.getenv("GPU_ARCHS", "native").split(";") - validate_and_update_archs(archs) - - cc_flag = [f"--offload-arch={arch}" for arch in archs] - - # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as - # torch._C._GLIBCXX_USE_CXX11_ABI - # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 - if FORCE_CXX11_ABI: - torch._C._GLIBCXX_USE_CXX11_ABI = True - - sources = ["csrc/flash_attn_ck/flash_api.cpp", - "csrc/flash_attn_ck/flash_common.cpp", - "csrc/flash_attn_ck/mha_bwd.cpp", - "csrc/flash_attn_ck/mha_fwd_kvcache.cpp", - "csrc/flash_attn_ck/mha_fwd.cpp", - "csrc/flash_attn_ck/mha_varlen_bwd.cpp", - "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob( - f"build/fmha_*wd*.cpp" - ) - - rename_cpp_to_cu(sources) - - renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", - "csrc/flash_attn_ck/flash_common.cu", - "csrc/flash_attn_ck/mha_bwd.cu", - "csrc/flash_attn_ck/mha_fwd_kvcache.cu", - "csrc/flash_attn_ck/mha_fwd.cu", - "csrc/flash_attn_ck/mha_varlen_bwd.cu", - "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") - - cc_flag += ["-O3","-std=c++17", - "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", - "-fgpu-flush-denormals-to-zero", - "-DCK_ENABLE_BF16", - "-DCK_ENABLE_BF8", - "-DCK_ENABLE_FP16", - "-DCK_ENABLE_FP32", - "-DCK_ENABLE_FP64", - "-DCK_ENABLE_FP8", - "-DCK_ENABLE_INT8", - "-DCK_USE_XDL", - "-DUSE_PROF_API=1", - # "-DFLASHATTENTION_DISABLE_BACKWARD", - "-D__HIP_PLATFORM_HCC__=1"] - - cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"] - - # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214 - hip_version = get_hip_version() - if hip_version > Version('5.7.23302'): - cc_flag += ["-fno-offload-uniform-block"] - if hip_version > Version('6.1.40090'): - cc_flag += ["-mllvm", "-enable-post-misched=0"] - if hip_version > Version('6.2.41132'): - 
cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true", - "-mllvm", "-amdgpu-function-calls=false"] - if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'): - cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] - - extra_compile_args = { - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": cc_flag + generator_flag, - } - - include_dirs = [ - Path(this_dir) / "csrc" / "composable_kernel" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", - ] + if USE_TRITON_ROCM: + # Skip C++ extension compilation if using Triton Backend + pass + else: + ck_dir = "csrc/composable_kernel" + + #use codegen get code dispatch + if not os.path.exists("./build"): + os.makedirs("build") + + os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2") + os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2") + os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2") + os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2") + + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pytorch/pytorch/pull/70650 + generator_flag = [] + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + check_if_rocm_home_none("flash_attn") + archs = os.getenv("GPU_ARCHS", "native").split(";") + validate_and_update_archs(archs) + + cc_flag = [f"--offload-arch={arch}" for arch in archs] + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + sources = ["csrc/flash_attn_ck/flash_api.cpp", + "csrc/flash_attn_ck/flash_common.cpp", + "csrc/flash_attn_ck/mha_bwd.cpp", + "csrc/flash_attn_ck/mha_fwd_kvcache.cpp", + "csrc/flash_attn_ck/mha_fwd.cpp", + "csrc/flash_attn_ck/mha_varlen_bwd.cpp", + "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob( + f"build/fmha_*wd*.cpp" + ) - ext_modules.append( - CUDAExtension( - name="flash_attn_2_cuda", - sources=renamed_sources, - extra_compile_args=extra_compile_args, - include_dirs=include_dirs, + rename_cpp_to_cu(sources) + + renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", + "csrc/flash_attn_ck/flash_common.cu", + "csrc/flash_attn_ck/mha_bwd.cu", + "csrc/flash_attn_ck/mha_fwd_kvcache.cu", + "csrc/flash_attn_ck/mha_fwd.cu", + "csrc/flash_attn_ck/mha_varlen_bwd.cu", + "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") + + cc_flag += ["-O3","-std=c++17", + "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", + "-fgpu-flush-denormals-to-zero", + "-DCK_ENABLE_BF16", + "-DCK_ENABLE_BF8", + "-DCK_ENABLE_FP16", + "-DCK_ENABLE_FP32", + "-DCK_ENABLE_FP64", + "-DCK_ENABLE_FP8", + "-DCK_ENABLE_INT8", + "-DCK_USE_XDL", + "-DUSE_PROF_API=1", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + "-D__HIP_PLATFORM_HCC__=1"] + + cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"] + + # Imitate 
https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214 + hip_version = get_hip_version() + if hip_version > Version('5.7.23302'): + cc_flag += ["-fno-offload-uniform-block"] + if hip_version > Version('6.1.40090'): + cc_flag += ["-mllvm", "-enable-post-misched=0"] + if hip_version > Version('6.2.41132'): + cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true", + "-mllvm", "-amdgpu-function-calls=false"] + if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'): + cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] + + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": cc_flag + generator_flag, + } + + include_dirs = [ + Path(this_dir) / "csrc" / "composable_kernel" / "include", + Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", + Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", + ] + + ext_modules.append( + CUDAExtension( + name="flash_attn_2_cuda", + sources=renamed_sources, + extra_compile_args=extra_compile_args, + include_dirs=include_dirs, + ) ) - ) def get_package_version(): diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py new file mode 100755 index 000000000..28f947beb --- /dev/null +++ b/tests/test_flash_attn_triton_amd.py @@ -0,0 +1,2162 @@ +import math +import os +import random + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from flash_attn import ( + flash_attn_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_with_kvcache, +) +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import _get_block_size_n +from flash_attn.layers.rotary import apply_rotary_emb +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, DEBUG, is_rdna, get_arch + +MAX_HEADDIM_SM8x = 192 + + +is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5) +is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8 +is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) +is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) + + +def attn_bias_from_alibi_slopes( + slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None +): + batch, nheads = slopes.shape + device = slopes.device + slopes = rearrange(slopes, "b h -> b h 1 1") + if causal: + return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes + else: + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + relative_pos = torch.abs(row_idx + sk - sq - col_idx) + return -slopes * relative_pos.to(dtype=slopes.dtype) + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = 
torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device + ) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) + return padding_mask + + +def generate_qkv( + q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + 
k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, + key_leftpad=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. 
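+        softcap: float. If > 0, scores are soft-capped as softcap * tanh(scores / softcap)
+            before masking and softmax.
+        key_leftpad: (batch_size,), optional number of left-padding positions of the key
+            sequence for each batch element.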
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + key_leftpad=key_leftpad, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def attention_kvpacked_ref( + q, + kv, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + return attention_ref( + q, + kv[:, :, 0], + kv[:, :, 1], + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + window_size=window_size, + softcap=softcap, + reorder_ops=reorder_ops, + key_leftpad=key_leftpad, + ) + + +def attention_qkvpacked_ref( + qkv, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, +): + return attention_ref( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + window_size=window_size, + softcap=softcap, + reorder_ops=reorder_ops, + ) + + +def generate_sparsity_mask(seqlen, sparsity=0.3): + repeats = seqlen // 16 // 2 + # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'), + # torch.tensor([0, 1] * 
repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'), + # torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1) + nrow, ncol = seqlen // 16, seqlen // 256 + mask = torch.rand(nrow, ncol, device="cuda") < sparsity + return mask + + +def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + blockmask: (seqlen / 16, seqlen / 256) + attn_mask: (batch_size, seqlen) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen, seqlen) + Output: + output: (batch_size, seqlen, nheads, head_dim) + attention: softmax after dropout + """ + q, k, v = qkv.float().unbind(dim=2) + d = qkv.shape[-1] + seqlen = qkv.shape[1] + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) + blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)") + blockmask = blockmask[:seqlen, :seqlen] + scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf")) + attention = torch.softmax(scores, dim=-1) + attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0) + attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0) + attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0) + return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype) + + +def convert_flash_attn_S_to_softmax( + S, + seqlen_q, + seqlen_k, + query_padding_mask, + key_padding_mask, + head_dim, + is_dropout, + causal=False, + window_size=(-1, -1), # -1 means infinite window size +): + """FlashAttention stores the S matrix in a different way. + Arguments: + S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded) + query_padding_mask: (batch_size, seqlen_q_rounded) + key_padding_mask: (batch_size, seqlen_k_rounded) + """ + if causal: + window_size = (window_size[0], 0) + seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] + S_converted = S + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + S.device, + ) + local_mask = F.pad( + local_mask, + (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), + value=True, + ) + S_converted = S_converted.masked_fill(local_mask, 0.0) + + # Need to zero out things not in attention_mask in case S was initialized with random values + # and some of those values aren't overwritten. 
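+    # The padding masks below are first padded out to the rounded (seqlen_q_rounded,
+    # seqlen_k_rounded) shape before masking, and the result is cropped back to
+    # (seqlen_q, seqlen_k) at the end.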
+ seqlen_q_og = ( + query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded + ) + if query_padding_mask is not None: + query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) + S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) + S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) + S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) + return S_converted[:, :, :seqlen_q, :seqlen_k] + + +def normalize_flash_attn_S( + attn_unnorm, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + is_dropout=False, + causal=False, + window_size=(-1, -1), # -1 means infinite window size +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k, v: (batch_size, seqlen_k, nheads, head_dim) + key_padding_mask: (batch_size, seqlen_q) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + Output: + softmax_lse: (batch_size, nheads, seqlen_q) + softmax_max: (batch_size, nheads, seqlen_q) + """ + if causal: + window_size = (window_size[0], 0) + q, k, v = q.float(), k.float(), v.float() + _, seqlen_q, _, head_dim = q.shape + seqlen_k = k.shape[1] + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias.to(dtype=scores.dtype) + block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) + scores_block = scores.split(block_size_n, dim=-1) + lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) + lse = torch.logsumexp(lse_block, dim=-1) + # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf + # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. + lse[lse == float("-inf")] = float("inf") + scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) + cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) + attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) + attn_norm = torch.cat( + [ + a * rearrange(torch.exp(m - lse), "b h s -> b h s 1") + for a, m in zip(attn_unnorm_block, cummax_block) + ], + dim=-1, + ) + if query_padding_mask is not None: + attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + return attn_norm.to(dtype=attn_unnorm.dtype) + + +def get_dropout_fraction( + dropout_mask, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size +): + """ + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. 
+ query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + """ + if causal: + window_size = (window_size[0], 0) + batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape + dropped = ~dropout_mask + valid = torch.ones_like(dropout_mask) + if query_padding_mask is not None: + dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) + valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) + if key_padding_mask is not None: + dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) + valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + dropout_mask.device, + ) + dropped.masked_fill_(local_mask, False) + valid.masked_fill_(local_mask, False) + dropped_total = dropped.sum() + return dropped.sum() / valid.sum() + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128]) +# @pytest.mark.parametrize("d", [256]) +# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) +@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048]) +# @pytest.mark.parametrize("seqlen", [97]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.17]) +def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): + if USE_TRITON_ROCM: + if get_arch() == "gfx90a": + if seqlen == 97 and d == 256 and dropout_p == 0.17: + pytest.skip("This config doesnot work on MI200 Devices.") + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: + pytest.skip() # Reference implementation OOM + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,)) + qkv = torch.randn( + batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True + ) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal) + else: + alibi_slopes, attn_bias = None, None + out, lse, S_dmask = flash_attn_qkvpacked_func( + qkv, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen, + seqlen, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = 
S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + attn = normalize_flash_attn_S( + attn_unnorm, + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + None, + None, + attn_bias, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_fraction = get_dropout_fraction( + dropout_mask, None, None, causal=causal, window_size=window_size + ).item() + print(f"Actual dropout fraction: {dropout_fraction}") + else: + dropout_mask = None + + out_ref, attn_ref = attention_qkvpacked_ref( + qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size + ) + out_pt, attn_pt = attention_qkvpacked_ref( + qkv, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + # v = qkv[:, :, 2].float() + # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float() + # if causal: + # causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1) + # qk.masked_fill_(causal_mask, float('-inf')) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # p_tmp = torch.softmax(qk / math.sqrt(d), -1) + # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values + # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values + # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values + # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values + # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:]) + # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:]) + # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:]) + # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :]) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if dropout_p > 0.0: + print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") + print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") + + g = torch.randn_like(out) + # do_o = (g.float() * out.float()).sum(-1) + # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64]) + # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:]) + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + (dqkv,) = torch.autograd.grad(out, qkv, g) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + 
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if DEBUG: + print("out:", out, out.shape) + print("out_ref:", out_ref, out_ref.shape) + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate + if not alibi: + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if DEBUG: + print("dqkv:", dqkv, dqkv.shape) + print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) + print("dqkv_pt:", dqkv_pt, dqkv_pt.shape) + assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32]) +@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048]) +# @pytest.mark.parametrize('seqlen', [128]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize('dropout_p', [0.0]) +def test_flash_attn_varlen_qkvpacked( + seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype +): + if USE_TRITON_ROCM: + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: + pytest.skip() # Reference implementation OOM + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + nheads = 6 + window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,)) + qkv = torch.randn( + batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True + ) + + key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") + # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes( + alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal + ) + else: + alibi_slopes, attn_bias = None, None + + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True + ) + + out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func( + qkv_unpad, + cu_seqlens, + max_seqlen, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + out = output_pad_fn(out_unpad) + if dropout_p > 
0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen, + seqlen, + key_padding_mask, + key_padding_mask, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + attn = normalize_flash_attn_S( + attn_unnorm, + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + attn_bias, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_fraction = get_dropout_fraction( + dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size + ).item() + print(f"Actual dropout fraction: {dropout_fraction}") + else: + dropout_mask = None + + out_ref, attn_ref = attention_qkvpacked_ref( + qkv, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_qkvpacked_ref( + qkv, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if dropout_p > 0.0: + print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") + print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") + + g = torch.randn_like(out) + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) + dqkv = dqkv_pad_fn(dqkv_unpad) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
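+    # out_pt runs the same reference math in the low-precision input dtype (upcast=False,
+    # reorder_ops=True), so it serves as a numerical noise floor for the checks below.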
+ if DEBUG: + print("out:", out, out.shape) + print("out_ref:", out_ref, out_ref.shape) + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate + if not alibi: + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + + +# @pytest.mark.parametrize("kvpacked", [True, False]) +@pytest.mark.parametrize("kvpacked", [False]) +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [32]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) +# @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("softcap", [0.0]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap +): + if USE_TRITON_ROCM: + if causal: + if seqlen_q ==1024 and seqlen_k==1024 and d==160: + pytest.skip("This test with causal=True is flakey") + + if softcap != 0.0: + pytest.skip("softcap not supported on AMD's Triton Backend yet") + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, 
requires_grad=True) + if softcap > 0: + # Ensure the values of qk are at least within softcap range. + q = q * softcap + if kvpacked: + kv = torch.randn( + batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + else: + k = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) + else: + alibi_slopes, attn_bias = None, None + + if kvpacked: + out, lse, S_dmask = flash_attn_kvpacked_func( + q, + kv, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out, lse, S_dmask = flash_attn_func( + q, + k, + v, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if DEBUG: + print("out:", out, out.shape) + print("lse:", lse, lse.shape) + print("S_dmask:", S_dmask, S_dmask.shape if S_dmask is not None else None) + + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen_q, + seqlen_k, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + if kvpacked: + kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) + k_rep, v_rep = kv_rep.unbind(dim=2) + else: + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) + attn = normalize_flash_attn_S( + attn_unnorm, + q, + k_rep, + v_rep, + None, + None, + attn_bias, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_fraction = get_dropout_fraction( + dropout_mask, None, None, causal=causal, window_size=window_size + ).item() + print(f"Actual dropout fraction: {dropout_fraction}") + else: + dropout_mask = None + + if kvpacked: + out_ref, attn_ref = attention_kvpacked_ref( + q, + kv, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_kvpacked_ref( + q, + kv, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + ) + else: + out_ref, attn_ref = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if dropout_p > 0.0: + print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") + print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") + + g = torch.randn_like(out) 
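+    # Backward check: the same upstream gradient g is fed to the kernel output and to both
+    # reference outputs so that the dq/dk/dv error comparisons below are apples-to-apples.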
+ do_o = (g.float() * out.float()).sum(-1) + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if kvpacked: + ( + dq, + dkv, + ) = torch.autograd.grad(out, (q, kv), g) + dk, dv = dkv.unbind(2) + ( + dq_ref, + dkv_ref, + ) = torch.autograd.grad(out_ref, (q, kv), g) + dk_ref, dv_ref = dkv_ref.unbind(2) + ( + dq_pt, + dkv_pt, + ) = torch.autograd.grad(out_pt, (q, kv), g) + dk_pt, dv_pt = dkv_pt.unbind(2) + else: + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if DEBUG: + print("out:", out, out.shape) + print("out_ref:", out_ref, out_ref.shape) + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + if DEBUG: + print("attn:", attn, attn.shape) + print("attn_ref:", attn_ref, attn_ref.shape) + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate + if not alibi: + if DEBUG: + print("dropout_fraction:", dropout_fraction) + print("dropout_p:", dropout_p) + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if DEBUG: + print("dv:", dv, dv.shape) + print("dv_ref:", dv_ref, dv_ref.shape) + print("dv_pt:", dv_pt, dv_pt.shape) + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + + if DEBUG: + print("dk:", dk, dk.shape) + print("dk_ref:", dk_ref, dk_ref.shape) + print("dk_pt:", dk_pt, dk_pt.shape) + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + + if DEBUG: + print("dq:", dq, dq.shape) + print("dq_ref:", dq_ref, dq_ref.shape) + print("dq_pt:", dq_pt, dq_pt.shape) + assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + + + +@pytest.mark.parametrize("kvpacked", [False]) +# @pytest.mark.parametrize('kvpacked', [False]) +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize('mha_type', ["mha"]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [False]) +# @pytest.mark.parametrize("local", 
[False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 147), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize('dropout_p', [0.0]) +# @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("softcap", [0.0]) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap +): + if USE_TRITON_ROCM: + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if softcap != 0.0: + pytest.skip("softcap not supported on AMD's Triton Backend yet") + + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if softcap > 0: + # Ensure the values of qk are at least within softcap range. 
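+ # (without this scaling, random fp16/bf16 inputs rarely produce scores large enough to hit the cap, so the softcap code path would go largely untested)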
+ q = q * softcap + + if kvpacked: + kv = torch.randn( + batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + else: + k = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") + # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes( + alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal + ) + else: + alibi_slopes, attn_bias = None, None + + if kvpacked: + ( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + kv, + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True) + out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + out_unpad, sm_lse, S_dmask = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if DEBUG: + print("out_unpad:", out_unpad, out_unpad.shape) + print("sm_lse:", sm_lse, sm_lse.shape) + + + out = output_pad_fn(out_unpad) + if dropout_p > 0.0: + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen_q, + seqlen_k, + query_padding_mask, + key_padding_mask, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + if kvpacked: + kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) + k_rep, v_rep = kv_rep.unbind(dim=2) + else: + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) + attn = normalize_flash_attn_S( + attn_unnorm, + q, + k_rep, + v_rep, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_fraction = get_dropout_fraction( + dropout_mask, + query_padding_mask, + key_padding_mask, + causal=causal, + window_size=window_size, + ).item() + print(f"Actual dropout fraction: {dropout_fraction}") + else: + dropout_mask = None + + if kvpacked: + out_ref, attn_ref = attention_kvpacked_ref( + q, + kv, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_kvpacked_ref( + q, + kv, + query_padding_mask, + 
key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + ) + else: + out_ref, attn_ref = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if dropout_p > 0.0: + print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") + print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") + + g = torch.randn_like(out) + if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)): + if kvpacked: + ( + dq_unpad, + dkv_unpad, + ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g) + dk, dv = dkv_pad_fn(dkv_unpad).unbind(2) + ( + dq_ref, + dkv_ref, + ) = torch.autograd.grad(out_ref, (q, kv), g) + dk_ref, dv_ref = dkv_ref.unbind(2) + ( + dq_pt, + dkv_pt, + ) = torch.autograd.grad(out_pt, (q, kv), g) + dk_pt, dv_pt = dkv_pt.unbind(2) + else: + ( + dq_unpad, + dk_unpad, + dv_unpad, + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + dq = dq_pad_fn(dq_unpad) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
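+ # (out_ref is computed with fp32 upcasting, so (out_pt - out_ref) measures the rounding error of a plain low-precision implementation; the kernel is held to at most twice that.)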
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + if dropout_p > 0.0: + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate + if not alibi: + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) + + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + if DEBUG: + print("dv:", dv, dv.shape) + print("dv_ref:", dv_ref, dv_ref.shape) + print("dv_pt:", dv_pt, dv_pt.shape) + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + + if DEBUG: + print("dk:", dk, dk.shape) + print("dk_ref:", dk_ref, dk_ref.shape) + print("dk_pt:", dk_pt, dk_pt.shape) + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + + if DEBUG: + print("dq:", dq, dq.shape) + print("dq_ref:", dq_ref, dq_ref.shape) + print("dq_pt:", dq_pt, dq_pt.shape) + assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [32]) +# @pytest.mark.parametrize("swap_sq_sk", [False, True]) +@pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): + if USE_TRITON_ROCM: + if is_rdna(): + if seqlen_q == 1 and seqlen_k == 239 and d == 256: + pytest.skip("This config does not work on RDNA devices.") + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + causal = True + # set seed + torch.random.manual_seed(0) + batch_size = 8 + nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size) + out_ref, attn_ref = attention_ref( + q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + None, + 0.0, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - 
out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + if DEBUG: + print("dv:", dv, dv.shape) + print("dv_ref:", dv_ref, dv_ref.shape) + print("dv_pt:", dv_pt, dv_pt.shape) + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + + if DEBUG: + print("dk:", dk, dk.shape) + print("dk_ref:", dk_ref, dk_ref.shape) + print("dk_pt:", dk_pt, dk_pt.shape) + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + + if DEBUG: + print("dq:", dq, dq.shape) + print("dq_ref:", dq_ref, dq_ref.shape) + print("dq_pt:", dq_pt, dq_pt.shape) + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("swap_sq_sk", [False, True]) +@pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) +@pytest.mark.parametrize("paged_kv_block_size", [None]) +# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) +def test_flash_attn_varlen_causal( + seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype +): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and 
torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + causal = True + # set seed + torch.random.manual_seed(0) + batch_size = 8 + nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + + if paged_kv_block_size is None: + k = torch.randn( + batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True + ) + block_table = None + else: + k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( + seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype + ) + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad if paged_kv_block_size is None else k_cache_paged, + v_unpad if paged_kv_block_size is None else v_cache_paged, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + causal=causal, + window_size=window_size, + block_table=block_table, + ) + out = output_pad_fn(out_unpad) + out_ref, attn_ref = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + None, + 0.0, + None, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + None, + 0.0, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + test_backward = block_table is None + if test_backward: + ( + dq_unpad, + dk_unpad, + dv_unpad, + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean 
diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + if test_backward: + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("num_splits", [1, 0]) +# @pytest.mark.parametrize("num_splits", [1]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("new_kv", [False, True]) +# @pytest.mark.parametrize("new_kv", [False]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [False]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("rotary_interleaved", [False, True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) +# @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) +@pytest.mark.parametrize("paged_kv_block_size", [None]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + (1, 128 * 1024), + (16, 128 * 1024), + (128, 128), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + has_batch_idx, + has_leftpad, + paged_kv_block_size, + rotary_fraction, + rotary_interleaved, + seqlen_new_eq_seqlen_q, + causal, + local, + alibi, + new_kv, + mha_type, + num_splits, + dtype, +): + if USE_TRITON_ROCM: + if paged_kv_block_size is not None: + pytest.skip("paged attention not supported on AMD's Triton Backend yet") + + if local == True: + pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + + if has_leftpad == True: + pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet") + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if has_batch_idx and paged_kv_block_size is not None: + pytest.skip() + if has_leftpad and paged_kv_block_size is not None: + pytest.skip() + device = "cuda" + # set seed + 
torch.random.manual_seed(0) + batch_size = 2 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + else: + k, v = None, None + if paged_kv_block_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + block_table = None + else: + ( + k_cache, + v_cache, + block_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes( + alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad + ) + else: + alibi_slopes, attn_bias = None, None + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, 
k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out = flash_attn_with_kvcache( + q, + k_cache if paged_kv_block_size is None else k_cache_paged, + v_cache if paged_kv_block_size is None else v_cache_paged, + k, + v, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + block_table=block_table, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + alibi_slopes=alibi_slopes, + num_splits=num_splits, + ) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + attn_bias, + 0.0, + None, + causal=causal, + window_size=window_size, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + attn_bias, + 0.0, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if paged_kv_block_size is None: + k_cache_select = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] + ) + v_cache_select = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] + ) + else: + k_cache_select = rearrange( + k_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache_select = rearrange( + v_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) + mult = 3 if not alibi else 5 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + + +def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): + num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + block_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + # pytorch 1.12 doesn't have indexing with int32 + k_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks +
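+# NOTE: _generate_block_kvcache returns two views of the same data: (k_cache, v_cache) are the contiguous caches used by the reference path, while (k_cache_paged, v_cache_paged, block_table) describe a paged layout with 3x over-allocated, randomly permuted blocks that the tests pass to the kernels when paged_kv_block_size is set.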