From 2d2dbe15593bc59f11fafe3b069f615786be608c Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 10 May 2024 10:48:42 -0500 Subject: [PATCH 01/20] Add Perf Kernels Add Perf Kernels This is a combination of 2 commits. Add Perf Kernels Add Perf Kernels This is a combination of 6 commits. add perf-kernels fix formating issues fix unused variables and other bugs fix other issues remove scripts save check changes format save save try pre-commit check save --- .github/workflows/amd_perf_kernel_tests.yml | 133 ++ .../03-matrix-multiplication-all-types.py | 377 ++++ .../03-matrix-multiplication-stream-k.py | 395 +++++ python/perf-kernels/06-attention-decode.py | 730 ++++++++ .../06-fused-attention-fwd-transV.py | 308 ++++ .../perf-kernels/06-fused-attention-transV.py | 928 ++++++++++ python/perf-kernels/README.md | 63 + python/perf-kernels/flash-attention.py | 1527 +++++++++++++++++ python/perf-kernels/hbm-bw-test.py | 200 +++ ...trix-multiplication-stream-k-oldversion.py | 485 ++++++ ...iplication-stream-k-singlekern-autotune.py | 563 ++++++ ...ultiplication-stream-k-singleloop-nomod.py | 387 +++++ 12 files changed, 6096 insertions(+) create mode 100644 .github/workflows/amd_perf_kernel_tests.yml create mode 100644 python/perf-kernels/03-matrix-multiplication-all-types.py create mode 100755 python/perf-kernels/03-matrix-multiplication-stream-k.py create mode 100644 python/perf-kernels/06-attention-decode.py create mode 100644 python/perf-kernels/06-fused-attention-fwd-transV.py create mode 100644 python/perf-kernels/06-fused-attention-transV.py create mode 100644 python/perf-kernels/README.md create mode 100644 python/perf-kernels/flash-attention.py create mode 100644 python/perf-kernels/hbm-bw-test.py create mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py create mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py create mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py diff --git a/.github/workflows/amd_perf_kernel_tests.yml b/.github/workflows/amd_perf_kernel_tests.yml new file mode 100644 index 000000000000..07424924a832 --- /dev/null +++ b/.github/workflows/amd_perf_kernel_tests.yml @@ -0,0 +1,133 @@ +name: AMD Perf Kernel Tests + +on: + workflow_dispatch: + pull_request: + branches: [main_perf] + merge_group: + branches: [main_perf] + types: [checks_requested] + push: + branches: [main_perf] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main_perf' }} + +permissions: read-all + +env: + TRITON_BUILD_WITH_CLANG_LLD: "TRUE" + TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" + TRITON_DISABLE_LINE_INFO: 1 + +jobs: + Check-File-Changes: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Check file changes + run: | + git fetch origin ${{ github.base_ref }} + changed_files=$(git diff --name-only origin/${{ github.base_ref }} ${{ github.sha }}) + echo "Changed files:" + echo "$changed_files" + if echo "$changed_files" | grep -v "^python/perf-kernels/"; then + echo "Changes detected outside of the python/perf-kernels directory. Failing the workflow." + exit 1 + fi + + Runner-Preparation-AMD: + runs-on: ubuntu-latest + timeout-minutes: 30 + outputs: + matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} + steps: + - name: Prepare runner matrix + id: set-matrix + run: | + if [ x"${{ github.repository }}" == x"ROCm/triton" ]; then + echo '::set-output name=matrix-HIP::[["self-hosted", "rocm.gfx90a"]]' + else + echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]' + fi + + pre-commit: + name: pre-commit (code formatting) + needs: Runner-Preparation-AMD + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + - name: Compute hash of pre-commit config + id: cache-key + run: | + echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml)" >> $GITHUB_OUTPUT + shell: bash + - name: Cache pre-commit's cache dir + uses: actions/cache@v4 + with: + # Note that we cannot use environment variables here given there is + # no shell to interpret them in the paths. + path: | + ~/.cache/pre-commit + key: ${{ runner.os }}-${{ steps.cache-key.outputs.pre_commit_hash }} + - name: Check pre-commit + run: | + python3 -m pip install --upgrade pre-commit + # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed + python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true + # If first run of yapf worked and made changes reset the tree to the original state + git reset --hard + python3 -m pre_commit run --all-files --verbose + - name: Print diff of changes if pre-commit failed + if: failure() + run: | + git diff + + Integration-Tests-AMD: + needs: Runner-Preparation-AMD + if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} + container: + image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Clear cache + run: | + rm -rf ~/.triton + mkdir -p ~/.triton + ls -alh ~/.triton + - name: Update PATH + run: | + echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH + - name: Install pip dependencies + run: | + python3 -m pip install --upgrade pip + python3 -m pip install lit matplotlib pandas + - name: Install Triton + run: | + echo "PATH is '$PATH'" + pip uninstall -y triton + cd python + pip install -v -e . + - name: Run Perf Kernels Unit Tests + run: | + pytest -vvv ./python/perf-kernels/flash-attention.py + - name: Run Perf Kernels Benchmark + run: | + python ./python/perf-kernels/flash-attention.py diff --git a/python/perf-kernels/03-matrix-multiplication-all-types.py b/python/perf-kernels/03-matrix-multiplication-all-types.py new file mode 100644 index 000000000000..1b0676079ede --- /dev/null +++ b/python/perf-kernels/03-matrix-multiplication-all-types.py @@ -0,0 +1,377 @@ +import torch + +import triton +import triton.language as tl +import sys +import argparse +import pytest +import re + + +@triton.autotune( + configs=[ + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + ], + key=['M', 'N', 'K'], + use_cuda_graph=True, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(c_ptr.type.element_ty) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, c, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + # assert a.is_contiguous(), "Matrix A must be contiguous" + # assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + ACTIVATION=activation, + ) + + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8e4': tl.float8e4b8, + 'fp8e5': tl.float8e5b16, +} + + +def gen_input(M, N, ty_name, needTrans, seed, device='cuda'): + d_type = name_to_tl_types[ty_name] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + if needTrans: + raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T + else: + raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda') + # avoid type conversion rounding errors of subnormal values + raw_data += 0.1 + if d_type == tl.float8e4b8: + raw_data += torch.sign(raw_data) + + if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ + (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8(): + input = raw_data.to(tl_to_torch_types[d_type]) + input_f16 = input.to(torch.float16) + else: + f8_tensor = raw_data.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + n_elements = raw_data.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + + return input, input_f16 + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., rocBLAS). +def get_x_vals(): + x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range(1, 9)] + + x_vals += [(4864, 4096, 8192), (9728, 8192, 65536)] + + return x_vals + + +@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype, col_a, col_b", [ + (*shape, in_dtype, out_dtype, col_a, col_b) + for shape in get_x_vals() + for in_dtype, out_dtype in [('fp16', 'fp16'), ('bf16', 'bf16'), ('fp16', + 'fp32'), ('fp32', + 'fp32'), ('fp8e4', + 'fp16'), ('fp8e5', 'fp16'), + #('int8', 'int8'), + ('int8', 'int32')] + # Only test k-major tensors because + # 1. This is the most preformant config and the current focus + # 2. Other case does not work with num_stages=0 (TODO (zhanglx)) + for col_a in [True, False] + for col_b in [True, False] +]) +def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype): + a, a_fp16 = gen_input(M, K, in_dtype, col_a, 1, device='cuda') + b, b_fp16 = gen_input(K, N, in_dtype, col_b, 2, device='cuda') + # Allocates output. + tl_out_dtype = name_to_tl_types[out_dtype] + torch_out_dtype = tl_to_torch_types[tl_out_dtype] + c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype) + matmul(a, b, c, activation="") + if in_dtype == 'fp8e4' or in_dtype == 'fp8e5' or in_dtype == 'int8': + # For f8 and int8 inputs, use fp16 for torch.matmul + torch_output = torch.matmul(a_fp16, b_fp16) + else: + torch_output = torch.matmul(a, b) + #print(f"triton_output={c}") + #print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + if in_dtype == 'int8': + torch.testing.assert_close(c.to(torch.float16), torch_output, atol=1e-3, rtol=rtol) + else: + torch.testing.assert_close(c, torch_output.to(torch_out_dtype), atol=5e-3, rtol=rtol) + + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of rocBLAS. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. + + +def get_type(provider): + res = re.findall(r'\(.*?\)', provider) + return res[0][1:-1] + + +inout_dtype = { + 'int8': torch.int8, + 'fp16': torch.float16, + 'fp32': torch.float32, + 'bf16': torch.bfloat16, + 'fp8e4': torch.float16, + 'fp8e5': torch.float16, +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot + x_vals=get_x_vals(), + line_arg='provider', # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=[ + 'rocblas(fp16)', 'rocblas(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)', + 'triton(fp8e5)' + ], + # Label name for the lines + line_names=[ + "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5" + ], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. + args={}, + )) +def benchmark(M, N, K, provider): + in_dtype = get_type(provider) + out_dtype = inout_dtype[in_dtype] + + quantiles = [0.5, 0.2, 0.8] + if 'rocblas' in provider: + a = torch.randn((M, K), dtype=tl_to_torch_types[name_to_tl_types[in_dtype]], device='cuda') + b = torch.randn((K, N), dtype=tl_to_torch_types[name_to_tl_types[in_dtype]], device='cuda') + + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + else: # triton, different data types + assert "triton" in provider + a, _ = gen_input(M, K, in_dtype, False, 1, device='cuda') + b, _ = gen_input(K, N, in_dtype, True, 2, device='cuda') + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=out_dtype) + + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, activation=""), quantiles=quantiles) + global verbose + if verbose: + print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})') + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="GEMM tutorial example", + allow_abbrev=False, + ) + + parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") + args = parser.parse_args() + + return args + + +def main(): + # assign to a global verbose var to indicate whether print + # best tuning config + global verbose + args = parse_args() + verbose = args.v + benchmark.run(show_plots=True, print_data=True) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/03-matrix-multiplication-stream-k.py b/python/perf-kernels/03-matrix-multiplication-stream-k.py new file mode 100755 index 000000000000..62d820719b9a --- /dev/null +++ b/python/perf-kernels/03-matrix-multiplication-stream-k.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") + +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +@triton.jit() +def streamk_gemm( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_full_tiles_streamk, + total_partial_tiles_streamk, + iters_per_tile, + total_tiles_streamk, + total_programs_streamk, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) + + # Determine whether we are in the first wave or full_tiles phase based on pid + is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 + + # Calculate starting and ending iterations for first wave + if not is_first_wave: + tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += BLOCK_K * stride_ak + B_BASE += BLOCK_K * stride_bk + # acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers +# rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) +# rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, acc) + else: + # start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for current_iter in range(start_iter, end_iter): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += BLOCK_K * stride_ak + B_BASE += BLOCK_K * stride_bk + + if remainder == 0 and end_iter % iters_per_tile == 0: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.atomic_add(C_, acc) + + start_iter = end_iter + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = True + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, + two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + GROUP_M = 4 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{total_full_tiles_streamk=}") + print(f"{total_partial_tiles_streamk=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + # allocates locks to sync work accross SMs + grids = total_programs_streamk + total_blocking_tiles + kk = streamk_gemm[(grids, )]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, + total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, + total_tiles_streamk=total_tiles_streamk, + total_programs_streamk=total_programs_streamk, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + ) + if matmul._debug: + print(f"{kk.n_regs} registers used, {kk.n_spills} spills") + + # print(kk.asm['ttgir']) + # print(kk.asm['amdgcn']) + + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, + num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, + mfmaInstrSize=mfmaInstrSize, kpack=kpack) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +#m, n, k = 4864, 4096, 8256 # some problem size to test +#m, n, k = 4096, 4096, 8192 # some problem size to test +#m, n, k = 8192, 8192, 8192 # some problem size to test +m, n, k = 6912, 768, 256 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +BLK_M = 64 +BLK_N = 64 +BLK_K = 64 +two_tiles = 'True' +num_stages = 0 +num_warps = 4 +waves_per_eu = 0 +mfmaInstrSize = 16 +kpack = 2 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, + kpack) +#exit(0) +matmul.set_debug(False) +expected = A @ B + +#assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" +print("pass validation test") + +# for debugging, uncomment the following line +# exit(0) + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, + waves_per_eu, mfmaInstrSize, kpack)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // BLK_M) * (n // BLK_N) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench( + lambda: wrapper_matmul(A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/06-attention-decode.py b/python/perf-kernels/06-attention-decode.py new file mode 100644 index 000000000000..3f38e5031eca --- /dev/null +++ b/python/perf-kernels/06-attention-decode.py @@ -0,0 +1,730 @@ +from typing import Optional +import pytest +import torch +import sys + +import triton +import triton.language as tl + + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + splitk_idx = tl.program_id(2) + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) + q = (q * qk_scale).to(q.dtype) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + 0, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p, v) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (0, BLOCK_N)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (BLOCK_N, 0)) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0, ), + ) + # Write metadata for split-K reduction + Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + + tl.arange(0, BLOCK_M)) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load(K_scale_shift_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v_scale_shift = tl.load(V_scale_shift_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + use_mask: tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) + + o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + + off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +class _attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } + SUPPORTED_MAX_K = 128 + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.empty([B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device) + metadata = torch.empty([B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + + num_warps = 1 + split_size = (Mk + split_k - 1) // split_k + use_seq_len = seq_len is not None + + # print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=num_warps, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + ) + + if mqa_swap_seqlen_head: + out = torch.empty((B, H, G, M, Kq), device=q.device, dtype=q.dtype).transpose(1, 3) + else: + out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + grid = (B * G * H, M, k_block_num) + _splitK_reduce[grid]( + o_splitk, metadata, out, lse, **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), M_ceil=M_ceil, BLOCK_SIZE=k_block_size, G=G, H=H, + # TODO: Tune num_warps + split_k=split_k, splitK_pow2=splitK_pow2, use_mask=use_mask, num_warps=4) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + + +attention = _attention.apply + + +def get_input_shapes(): + cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)] + + return cases + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) +def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(20) + q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=0., + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + scale = 1 / K**0.5 + tri_out = attention(q, k, v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0) + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) +def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(2) + q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, + device="cuda").normal_(mean=1.0, std=0.5).requires_grad_()) + k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, + device="cuda").normal_(mean=1.0, + std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + + num_groups = 1 + quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32)) + quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32)) + scale = 1 / K**0.5 + tri_out = attention(q, quant_k, quant_v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) + + # since quantization introduces rounding error, use the + # dequantized kv as inputs to the ref implementation to reduce + # the tolerance to 1e-3 + dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) + dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) + dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) + dq_ref_out = dq_attn @ dqv + torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + + +def test_quantization(): + a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') + qa = quantize_kv_int4(a, num_groups=4) + dqa = dequantize_kv_fp16(qa, num_groups=4) + torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) + + +try: + FLASH_VER = 2 +except BaseException: + try: + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +configs = [] +for mode in ['fwd']: + # for D_HEAD in [128]: + for causal in [False]: + configs.append( + triton.testing.Benchmark( + x_names=['B', 'Mq', 'Mkv', 'Hq', 'Hkv', 'K'], x_vals=get_input_shapes(), line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), styles=[('red', '-'), + ('blue', '-')], + ylabel='ms', plot_name=f'fused-attention-d{128}-{mode}-causal={causal}', args={ + # 'D_HEAD': D_HEAD, + 'dtype': torch.float16, 'mode': mode, 'causal': causal + })) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 100 + rep = 400 + ms = 0 + if provider == "triton": + q = torch.randn([B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=False) + k = torch.randn([B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, + requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1) + v = torch.randn([B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, + requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1) + + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + # flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K) + # total_flops = 2 * flops_per_matmul + # totalBytes = ((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2 + + # return totalBytes / ms * 1e-9 + return ms * 1000 + + +def main(): + bench_flash_attention.run(save_path='.', print_data=True) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/06-fused-attention-fwd-transV.py b/python/perf-kernels/06-fused-attention-fwd-transV.py new file mode 100644 index 000000000000..53517a395c8d --- /dev/null +++ b/python/perf-kernels/06-fused-attention-fwd-transV.py @@ -0,0 +1,308 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. Goucher for simplified vector math + +""" + +import pytest +import torch +import sys + +import triton +import triton.language as tl + +# Pick the fp8 data type + +# AMD E5M2B16 +# float8:tl.constexpr = torch.float8_e5m2fnuz + +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8: tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + sm_scale, + M, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr(base=Q + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qkv_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(0, 1)) + # initialize offsets + # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr(base=Out + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + if torch.version.hip is None: + BLOCK_M = 128 + BLOCK_N = 64 if Lk <= 64 else 32 + num_stages = 4 if Lk <= 64 else 3 + num_warps = 4 if Lk <= 64 else 8 + + ## hardcoded best perf_configs for MI250 + if Lk == 64: + ## D_HEAD = 64 + BLOCK_M = 128 + BLOCK_N = 64 + waves_per_eu = 3 + num_warps = 4 + num_stages = 1 + ## causal=False likes to pre load v but causal=True does not + pre_load_v = False if causal else True + slice_k_tile = 32 + kpack = 1 + else: + ## D_HEAD = 128 + ## For fp16, pick BLOCK_M=256, num_warps=8 + ## For fp8, pick BLOCK_M=128, num_warps=4 + ## TODO (zhanglx): add tuning infra for FA + BLOCK_M = 128 if TORCH_HAS_FP8E4 and q.dtype == torch.float8_e4m3fnuz else 256 + BLOCK_N = 128 + waves_per_eu = 2 + num_warps = BLOCK_M // 32 + num_stages = 1 + pre_load_v = False + slice_k_tile = 32 + kpack = 1 + + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[grid]( + q, + k, + v, + sm_scale, + M, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + waves_per_eu=waves_per_eu, + num_warps=num_warps, + num_stages=num_stages, + pre_load_v=pre_load_v, + slice_k_tile=slice_k_tile, + kpack=kpack, + ) + + return o + + +attention = _attention.apply + +name_to_torch_types = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp8': float8} + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype', + [(*shape, dtype) + for shape in [(4, 48, 1024, 128), (4, 48, 2048, 128), (4, 48, 4096, 128)] + for dtype in ['fp16', 'bf16', 'fp8']]) +def test_op_fwd(Z, H, N_CTX, D_HEAD, dtype): + torch.manual_seed(20) + init_dtype = torch.float16 if dtype == 'fp8' else name_to_torch_types[dtype] + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + v = (torch.empty((Z, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + sm_scale = 0.5 + # reference implementation + # M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + p = torch.softmax(p.float(), dim=-1).to(q.dtype) + ref_out = torch.matmul(p, v.transpose(2, 3)) + # triton implementation + # q,k casting for partial fp8 + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + # dout = torch.randn_like(q, dtype=torch.float16) + tri_out = attention(q, k, v, sm_scale) + # compare + atol = 1.4e-1 if dtype == 'fp8' else 1e-2 + rtol = 1e-2 if dtype == 'fp8' else 3e-3 + torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol) + + +try: + FLASH_VER = 2 +except BaseException: + try: + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +# vary seq length for fixed head and batch=4 +configs = [] +for dtype in ['fp16', 'bf16', 'fp8']: + for D_HEAD in [128]: + for causal in [False]: + configs.append( + triton.testing.Benchmark( + x_names=['BATCH', 'H', 'N_CTX'], x_vals=[ + (16, 16, 1024), + (8, 16, 2048), + (4, 16, 4096), + (2, 16, 8192), + (1, 16, 16384), + (4, 48, 1024), + (4, 48, 2048), + (4, 48, 4096), + (4, 48, 8192), + (4, 48, 16384), + ], line_arg='provider', line_vals=['triton'], line_names=['Triton'], + #styles=[('red', '-'), ('blue', '-')], + ylabel='ms', plot_name=f'fused-attention-fwd-d{D_HEAD}-causal={causal}-{dtype}', + args={'D_HEAD': D_HEAD, 'dtype': dtype, 'causal': causal})) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, provider, dtype, device="cuda"): + if dtype == 'fp8' and not TORCH_HAS_FP8E4: + sys.exit("fp8 is not available") + warmup = 25 + rep = 100 + init_dtype = torch.float16 if dtype != 'bf16' else torch.bfloat16 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + # q,k casting for partial fp8 + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + fn = lambda: attention(q, k, v, sm_scale) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + return total_flops / ms * 1e-9 + + +def main(): + bench_flash_attention.run(save_path='.', print_data=True) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/06-fused-attention-transV.py b/python/perf-kernels/06-fused-attention-transV.py new file mode 100644 index 000000000000..60113d3aa17d --- /dev/null +++ b/python/perf-kernels/06-fused-attention-transV.py @@ -0,0 +1,928 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. Goucher for simplified vector math + +""" + +import pytest +import torch + +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') +if TORCH_HAS_FP8E5: + torch_dtype: tl.constexpr = torch.float8_e5m2fnuz + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, + N_CTX, + pre_load_v: tl.constexpr, +): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # causal = False + else: + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = tl.where(mask, qk, float("-inf")) + qk += tl.dot(q, k) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': True}, + num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 0, 'pre_load_v': False}, + num_stages=1, num_warps=4), # d64-True + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': True}, + num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 32, 'pre_load_v': False}, + num_stages=1, num_warps=4), # d64-True + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': True}, + num_stages=1, num_warps=4), # d64-False + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'slice_k_tile': 64, 'pre_load_v': False}, + num_stages=1, num_warps=4), # d64-True + ], + key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], +) +@triton.jit +def _attn_fwd( + Q, + K, + V, + sm_scale, + M, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + STAGE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr(base=Q + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qkv_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(0, 1)) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(q.dtype) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + N_CTX, + pre_load_v, + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compiler to schedule the + # two loops independently + tl.debug_barrier() + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + 2, + offs_m, + offs_n, + N_CTX, + pre_load_v, + ) + # epilogue + # write back m + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr(base=Out + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _bwd_preprocess( + Out, + DO, + NewDO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + P_SEQ, + num_block_q, + num_block_kv, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + qk_scale = sm_scale * 1.44269504 + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_kz + off_h * stride_kh + DV += off_z * stride_vz + off_h * stride_vh + # See fwd pass above for explanation. + qk_scale = sm_scale * 1.44269504 + for start_n in range(0, num_block_kv): + if CAUSAL: + lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0) + else: + lo = 0 + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[None, :] * stride_qm + offs_k[:, None] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + # initialize dk amd dv + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + if CAUSAL: + qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk * qk_scale - l_i[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(Q.dtype.element_ty), k) + tl.store(dq_ptrs, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + tl.store(dk_ptrs, dk) + tl.store(dv_ptrs, dv) + + +@triton.jit +def _bwd_kernel_dk_dv( + Q, + K, + V, + sm_scale, + Out, + DO, + DK, + DV, + L, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # Q is consumed depending on block ID. Every block uses + # previous block offset by BLOCK_M x D_HEAD. + qvk_offset = off_hz * stride_qh + qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # offs_d = tl.arange(0, BLOCK_DMODEL) + # Initialize pointers to Q, K, V + Q_block_ptr = tl.make_block_ptr(base=Q + qdo_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, start_m * BLOCK_M), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_vn, stride_vk), + offsets=(0, start_m * BLOCK_M), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO + qdo_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0)) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + qk_scale = sm_scale * 1.44269504 + # load k and v: they will stay in SRAM throughout + k = tl.load(K_block_ptr) + k = (k * qk_scale).to(k.dtype) + v = tl.load(V_block_ptr) + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # This lower loop bound is because of the causal mask. We create a lower triangular + # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can + # be ignored in the GEMM. + lo = start_m * BLOCK_M + hi = N_CTX + # loop over q, do + for start_n in range(lo, hi, BLOCK_N): + offs_m_curr = offs_n[:, None] + start_n + # -- load q, do -- + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + # -- compute qk ---- + qk = tl.dot(q, k) + qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf")) + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i) + # -- compute dv ---- + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di + dp += tl.dot(do, v) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp + # compute dk + dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) + # update pointers + Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0)) + # initialize pointers to output + DK_block_ptr = tl.make_block_ptr(base=DK + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_kn, stride_kk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + DV_block_ptr = tl.make_block_ptr(base=DV + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(DK_block_ptr, (dk * sm_scale).to(k.dtype)) + tl.store(DV_block_ptr, dv.to(v.dtype)) + + +@triton.jit +def _bwd_kernel_dq( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + L, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # offs_d = tl.arange(0, BLOCK_DMODEL) + # Initialize pointers to Q, K, V + Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + K_block_ptr = tl.make_block_ptr(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + V_block_ptr = tl.make_block_ptr(base=V + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_vn, stride_vk), + offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + qk_scale = sm_scale * 1.44269504 + # load q and do: they will stay in SRAM throughout + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(q.dtype) + do = tl.load(DO_block_ptr) + Di = tl.load(D_ptrs + offs_m) + l_i = tl.load(l_ptrs + offs_m) + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # loop over k, v + lo = 0 + hi = (start_m + 1) * BLOCK_M + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk ---- + qk = tl.dot(q, k) + qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf")) + p = tl.math.exp2(qk - l_i[:, None]) + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp + # compute dq. Unfortunately we cannot avoid transpose here as this loop + # uses k both normal and transpose. + dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k)) + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N)) + # initialize pointers to output + DQ_block_ptr = tl.make_block_ptr(base=DQ + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + tl.store(DQ_block_ptr, (dq * sm_scale).to(q.dtype)) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + if torch.version.hip is None: + # BLOCK_M = 128 + # BLOCK_N = 64 if Lk <= 64 else 32 + # num_stages = 4 if Lk <= 64 else 3 + # num_warps = 4 if Lk <= 64 else 8 + pass + + stage = 3 if causal else 1 + grid = lambda META: (triton.cdiv(q.shape[2], META['BLOCK_M']), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[grid]( + q, + k, + v, + sm_scale, + M, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + STAGE=stage, + ) + + ## restore the grid for bwd kernel + best_config = _attn_fwd.get_best_config() + block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) + grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.split_kernel = split_kernel + return o + + @staticmethod + def backward(ctx, do): + # configuration is not supported + assert (not (ctx.split_kernel and not ctx.causal)) + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, L = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + delta = torch.empty_like(L) + do_scaled = torch.empty_like(do) + # Figure out what BLOCK size fwd used and adjust num_blocks accordingly. + # If the two are the same, we don't need this but the bwd pass block size + # is smaller than the fwd so we need this scaling to ensure we loop over all + # values and don't skip some blocks. + # Alternatively we could compute a new grid but this keeps it consistent + # with fwd and easier to reason about. + block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, + do, + do_scaled, + delta, + BLOCK_M=block_scale * BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + if not ctx.split_kernel: + _bwd_kernel[(ctx.grid[1], )]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + dk, + dv, + L, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + block_scale * ctx.grid[0], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=4, + CAUSAL=ctx.causal, + num_stages=1, + ) + else: + dq = torch.zeros_like(q) + _bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dk, + dv, + L, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=4, + num_stages=1, + ) + _bwd_kernel_dq[ctx.grid]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + L, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=2 * BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=4, + waves_per_eu=1, + num_stages=1, + ) + # print(h.asm["ttgir"]) + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (4, 48, 1024, 128), + (4, 48, 2048, 128), + (4, 48, 4096, 128), + #(4, 48, 8192, 64), + #(4, 48, 16384, 64) +]) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + v = (torch.empty((Z, H, D_HEAD, N_CTX), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()) + sm_scale = 0.5 + # dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).to(v.dtype) + ref_out = torch.matmul(p, v.transpose(2, 3)) + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale) + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (1, 16, 8192, 64), +]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + causal = True + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = 0, 5 + split_kernel = True + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).to(v.dtype) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # # triton implementation + tri_out = attention(q, k, v, causal, sm_scale, split_kernel) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + if torch.version.hip is None: + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0) + # The current block size for MI200 series is 64x64. This results in + # larger differences in float results due to rounding. + else: + assert torch.allclose(ref_dv, tri_dv, atol=5e-2, rtol=0) + assert torch.allclose(ref_dk, tri_dk, atol=5e-2, rtol=0) + assert torch.allclose(ref_dq, tri_dq, atol=5e-2, rtol=0) + + +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +name_to_torch_types = { + 'fp16': torch.float16, + 'bf16': torch.bfloat16, +} + +# vary seq length for fixed head and batch=4 +configs = [] +for mode in ['fwd']: + for dtype in ["fp16", "bf16"]: + for D_HEAD in [128, 64]: + for causal in [False, True]: + configs.append( + triton.testing.Benchmark( + x_names=['BATCH', 'H', 'N_CTX'], x_vals=[ + (16, 16, 1024), + (8, 16, 2048), + (4, 16, 4096), + (2, 16, 8192), + (1, 16, 16384), + (4, 48, 1024), + (4, 48, 2048), + (4, 48, 4096), + (4, 48, 8192), + (4, 48, 16384), + ], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), styles=[('red', '-'), + ('blue', '-')], + ylabel='ms', plot_name=f'fused-attention-d{D_HEAD}-{mode}-causal={causal}-{dtype}', + args={'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode, 'causal': causal})) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + init_dtype = name_to_torch_types[dtype] + split_kernel = False + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + split_kernel = True + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=init_dtype, device=device, requires_grad=True) + if FLASH_VER == 1: + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) + elif FLASH_VER == 2: + fn = lambda: flash_attn_func(qkv, causal=causal) + else: + raise ValueError(f'unknown {FLASH_VER = }') + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + if mode == 'bwd': + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + return total_flops / ms * 1e-9 + + +# only works on post-Ampere GPUs right now +bench_flash_attention.run(save_path='.', print_data=True) diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md new file mode 100644 index 000000000000..5bcedbf49cdd --- /dev/null +++ b/python/perf-kernels/README.md @@ -0,0 +1,63 @@ +# AMD Perf Kernels + +This directory contains customized/tuned/experimental kernels for AMD Instinct series GPUs. +Please make sure your Triton compiler is v2.1 or later, and is from the OpenAI Triton repository +[here](https://github.com/openai/triton). To install Triton, please see +[these](https://github.com/openai/triton/tree/main?tab=readme-ov-file#install-from-source) instructions. + +## `06-fused-attention-transV.py` + +This script is a copy of `tutorials/06-fused-attention.py` with the following +two changes: + +- Tensor V is transposed in the way that seqlen/N_CTX dimension becomes the +fastest changing (a.k.a. leading or least strided) dimension. +This script produces better performance than `tutorials/06-fused-attention.py` +since it has better LDS access efficiency for tensor V. +Note that in the future, we'll improve the LDS access efficiency for +non-transposed tensor V, i.e. head dimension is the fastest changing dimension. +- Only fwd kernel is benchmarked. + +## `06-fused-attention-fwd-transV.py` + +This script is used to produce the best performance for fwd kernel. +It is a copy of `06-fused-attention-transV.py` with the following +changes: + +- All bwd kernels are removed. +- Storing `m` at the end of the fwd kernel is removed. +- Autotuner is removed. All parameters for D=64 ad D=128 are pre-tuned +on MI250X and hard coded. + +Note that this script is also used to benchmark FA performance with 2 GCDs. +Check the [2GCD benchmark script](https://github.com/ROCmSoftwarePlatform/triton/blob/triton-mlir/scripts/amd/benchmark_flash_attention.py) for more details. + +## `flash-attention.py` + +This script contains the Flash Attention kernel with the following support + +- Arbitrary Q and KV sequence lengths, and arbitrary head sizes +- Autoregressive or "causal" masking +- Flash Attention v2 with variable sequence lengths +- Multi and Grouped Query attention +- ALiBi bias +- Matrix bias + +These are currently supported for the forward kernel only. + +## `06-attention-decode.py` + +This contains the Flash Decoding kernel. + +## `hbm-bw-test.py` + +This is a script that measures HBM bandwidth performance on your device. + +## `03-matrix-multiplication-all-types.py` + +This script contains the GEMM kernel that supports int8, int32, fp16, +fp32, bf16 and f8 (both e5m2 and e4m3) datatypes. + +## `03-matrix-multiplication-stream-k.py` + +This script contains the GEMM kernel that implements [stream-k](https://arxiv.org/abs/2301.03598) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py new file mode 100644 index 000000000000..6fc861b281fa --- /dev/null +++ b/python/perf-kernels/flash-attention.py @@ -0,0 +1,1527 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import argparse +import pytest +import sys +import torch + +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') +if TORCH_HAS_FP8E5: + torch_dtype: tl.constexpr = torch.float8_e5m2fnuz + + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + dropout_p, return_encoded_softmax = 0.0, False + + def __init__(self, sm_scale=1.0): + self.sm_scale = sm_scale + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. + assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def need_dropout(self, dropout_p, return_encoded_softmax): + self.dropout_p = dropout_p + self.return_encoded_softmax = return_encoded_softmax + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + if self.varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def print_gpu(prefix, val=None): + if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): + if val is not None: + tl.device_print(prefix, val) + else: + tl.device_print(prefix) + + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, + philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, + masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + use_cuda_graph=True, +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + stride_az, + stride_ah, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + alibi_slopes, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_ALIBI: tl.constexpr, + BATCH_SIZE: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1)) + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + # need_padding = False + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + # need_padding = True + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + # need_padding = True + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr(base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1)) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0)) + if BIAS_TYPE != 0: + b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs + bias_ptr = tl.make_block_ptr( + base=bias + b_offset, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if ENABLE_DROPOUT: + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr(base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0)) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, + alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + # This is a > check because mask being 0 blocks the store. + l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +@triton.jit +def _attn_bwd_preprocess( + Out, + DO, + Delta, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_doz, + stride_doh, + stride_dom, + stride_don, + seqlen_q, + head_dim, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + # off_n = tl.arange(0, D_HEAD) + off_m = tl.program_id(0) * BLOCK_M + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + num_h = tl.num_programs(1) + o_offset = off_h * stride_oh + off_z * stride_oz + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + do_offset = off_h * stride_doh + off_z * stride_doz + DO_block_ptr = tl.make_block_ptr(base=DO + do_offset, shape=(seqlen_q, head_dim), strides=(stride_dom, stride_don), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + # load + # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + do = tl.load(DO_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) + off_zh = off_z * num_h + off_h * 1 + # Check for OOB accesses + delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M) + overflow = off_m + BLOCK_M - seqlen_q + if overflow > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32) + mask = boundary > tl.arange(0, BLOCK_M) + tl.store(delta_ptrs, delta, mask=mask) + else: + tl.store(delta_ptrs, delta) + + +@triton.jit +def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + # offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr(base=Q, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1, 0)) + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(QT_block_ptr) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + kqT = tl.dot(k, qT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) + kqT += alibi_block * 1.44269504089 + + pT = tl.math.exp2(kqT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(DO_block_ptr) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + return dk, dv + + +@triton.jit +def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + # offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr(base=K, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + VT_block_ptr = tl.make_block_ptr(base=V, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(KT_block_ptr) + qk = tl.dot(q, kT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n) + qk += alibi_block * 1.44269504089 + + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + vT = tl.load(VT_block_ptr) + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ.0. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + + if USE_ALIBI: + a_offset = bhid + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + # compute dK and dV for blocks close to the diagonal that need to be masked + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + + # compute dK and dV for blocks that don't need masking further from the diagonal + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + + DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DV_block_ptrs, dv.to(v.dtype)) + + # Write back dK. + dk *= sm_scale + DK_block_ptrs = tl.make_block_ptr(base=DK, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DK_block_ptrs, dk.to(k.dtype)) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + Q_block_ptr = tl.make_block_ptr(base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, MASK_BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, MASK=True) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * BLOCK_N2, num_steps, MASK=False) + # Write back dQ. + DQ_block_ptr = tl.make_block_ptr(base=DQ, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + dq *= LN2 + tl.store(DQ_block_ptr, dq.to(q.dtype)) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, o, metadata): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) + + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + metadata.check_args(q, k, v, o) + if metadata.varlen: + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = metadata.num_contexts + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + + # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, + dtype=torch.float32) + else: + encoded_softmax = None + + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), metadata.bias.stride(2), + metadata.bias.stride(3)) + else: + bias_strides = (0, 0, 0, 0) + + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, *alibi_strides, metadata.cu_seqlens_q, metadata.cu_seqlens_k, + dropout_p=metadata.dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, + BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1, + USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = metadata.sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = metadata.return_encoded_softmax + return o, encoded_softmax + + @staticmethod + def backward(ctx, do, _): + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + seqlen_q = q.shape[2] + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + # NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + assert N_CTX % PRE_BLOCK == 0 + delta = torch.empty_like(M) + _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] + # padded_head = (Lk != ctx.BLOCK_DMODEL) + grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) + _attn_bwd_preprocess[grid_preprocess]( + o, + do, + delta, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + seqlen_q, + head_dim=Lk, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + ctx.alibi_slopes, + do, + dq, + dk, + dv, + M, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + N_HEAD, + N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + USE_ALIBI=False if ctx.alibi_slopes is None else True, + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply + + +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): + torch.manual_seed(20) + + # Initialize q, k, v + q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + return q, k, v, input_metadata + + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): + torch.manual_seed(20) + + # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + max_seqlens_q = torch.max(seqlens_q).item() + max_seqlens_k = torch.max(seqlens_k).item() + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + # -1 because the last entry of cu_seqlens_q specifies the end of the last seq + # num_ctxs = len(cu_seqlens_q) - 1 + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + if TORCH_HAS_FP8E5: + q = q.to(torch_dtype) + k = k.to(torch_dtype) + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = attention(q, k, v, o, input_metadata) + + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 24, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 16, 15498, 2, 128), + (2, 16, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('use_bias', [True]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): + pytest.skip() + torch.manual_seed(20) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + if causal: + input_metadata.need_causal() + if use_bias: + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") + input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) + else: + bias = None + q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + if TORCH_HAS_FP8E5: + q = q.to(torch_dtype) + k = k.to(torch_dtype) + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = attention(q, k, v, o, input_metadata) + # reference implementation:171 + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_bias: + scores += input_metadata.bias + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64), + (4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64), + (4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128), + (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + pytest.skip() + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, D_HEAD, dtype) + tri_out = torch.empty_like(q) + ref_out = torch.empty_like(q) + + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), + (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), + (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), + (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), + (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) + ref_out = torch.empty_like(q) + tri_out = torch.empty_like(q) + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the + # size aligns with Q. + k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_ref[start_k:end_k] + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (2, 48, 4096, 64), + (1, 16, 1024, 64), + (1, 16, 1024, 128), + #(1, 16, 8192, 63), + #(1, 16, 1022, 64), +]) +@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) +@pytest.mark.parametrize('torch_sdpa_test', [False, True]) +@pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('use_alibi', [False, True]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, + dtype=torch.float16): + torch.manual_seed(20) + if qseqlen_not_equal_kseqlen is not None: + seqlen_q = qseqlen_not_equal_kseqlen + else: + seqlen_q = N_CTX + seqlen_k = N_CTX + + if causal and ((N_CTX - 1) & N_CTX): + pytest.skip() + if causal and seqlen_q != seqlen_k: + pytest.skip() + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = seqlen_q + input_metadata.max_seqlens_k = seqlen_k + + dropout_p = 0 + q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + o = torch.empty_like(q) + + if causal: + input_metadata.need_causal() + + if use_alibi and not torch_sdpa_test: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, H) + dout = torch.randn_like(q) + # reference implementation + if torch_sdpa_test: + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, + is_causal=causal, scale=sm_scale, + dropout_mask=None) + ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + else: + M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if use_alibi: + p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) + if causal: + p[:, :, M == 0] = float("-inf") + + p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # # triton implementation + tri_out, _ = attention(q, k, v, o, input_metadata) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # test + #print("reference") + #print(ref_dv) + #print("tri") + #print(tri_dv) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + # The current block size for MI200 series is 64x64. This results in + # larger differences in float results due to rounding. + + if dtype == torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + if dtype == torch.float32: + ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + else: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + + RTOL = 0 + + torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) + + +def nonvarlen_benchmark_configs(): + configs = [ + (16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), + (2, 48, 48, 1024, 1024), + (2, 48, 48, 2048, 1024), + (2, 48, 48, 4096, 8192), + (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + (8, 16, 16, 1989, 15344), + (4, 16, 16, 4097, 163), + (2, 16, 16, 8122, 2159), + (1, 16, 16, 16281, 7), + (2, 48, 48, 1021, 1020), + (2, 48, 48, 2001, 2048), + (2, 48, 48, 3996, 9639), + (2, 48, 48, 8181, 1021), + ] + return configs + + +def varlen_benchmark_configs(): + configs = [ + (2, 16, 4, 1024, 1024), + (8, 16, 2, 2048, 2048), + (4, 16, 8, 4096, 4096), + (2, 16, 4, 8192, 8192), + (2, 16, 8, 16384, 16384), + (2, 48, 12, 1024, 1024), + (2, 48, 24, 2048, 2048), + (2, 48, 8, 4096, 4096), + (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + (2, 64, 32, 1024, 1024), + (4, 64, 16, 2048, 2048), + (4, 64, 8, 4096, 4096), + (4, 64, 32, 8192, 8192), + (4, 128, 16, 16384, 16384), + ] + return configs + + +def run_benchmark(custom): + + args = parse_args() + dtype = arg_to_torch_dtype[args.dtype] + # hk = args.hq if not args.hk else args.hk + # sk = args.sq if not args.sk else args.sk + head_size = 128 if not args.d else args.d + mode = 'fwd' + x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + causal = args.causal + varlen = args.varlen + configs = [] + if custom: + x_vals_list = [(args.b, args.hq, args.hk, args.sq, args.sk)] + else: + if varlen: + x_vals_list = varlen_benchmark_configs() + else: + x_vals_list = nonvarlen_benchmark_configs() + print_time = args.return_time + line_names = 'Time (ms)' if print_time else 'TFLOPS' + configs.append( + triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], + line_names=[line_names], styles=[('red', '-')], ylabel='ms', + plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}', + args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) + + @triton.testing.perf_report(configs) + def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"): + assert mode in ["fwd", "bwd"] + warmup = 25 + rep = 100 + # TODO: Enable bias after testing. + # if use_bias: + # bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + # input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX) + # else: + # bias = None + # bias = None + + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + + flops_per_matmul = 0 + if varlen: + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + for i in range(0, input_metadata.num_contexts): + seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] + seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] + # x2 for 2 GEMMs + flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 + else: + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + if causal: + input_metadata.need_causal() + o = torch.empty_like(q) + fn = lambda: attention(q, k, v, o, input_metadata) + if mode == 'bwd': + o, _ = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + total_flops = 2 * flops_per_matmul + # TODO: This needs to be fixed for unequal Q/K seqlens + if causal: + total_flops *= 0.5 + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + if print_time: + return ms + else: + return total_flops / ms * 1e-9 + + bench_flash_attention.run(save_path=".", print_data=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark FlashAttention", + allow_abbrev=False, + ) + parser.add_argument("-b", type=int, default=0) + parser.add_argument("-hq", type=int, default=0) + parser.add_argument("-hk", type=int, default=0) + parser.add_argument("-sq", type=int, default=0) + parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-d", type=int, default=0) + parser.add_argument("-causal", action='store_true', default=False) + parser.add_argument("-varlen", action='store_true', default=False) + parser.add_argument("-dtype", default='fp16') + parser.add_argument("-return_time", action='store_true', default=False) + return parser.parse_args() + + +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def main(): + args = parse_args() + custom_config = False + if args.b or args.hq or args.hk or args.sq or args.sk or args.d: + custom_config = True + assert args.b and args.hq and args.sq and args.d, \ + "If custom config is specified, please provide \ + all of batch, number of Q heads, Q sequence length \ + and head size." + + assert args.dtype in arg_to_torch_dtype, \ + "Only fp16, bf16 and f32 types currently supported." + + run_benchmark(custom_config) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/hbm-bw-test.py b/python/perf-kernels/hbm-bw-test.py new file mode 100644 index 000000000000..a20ce044eaee --- /dev/null +++ b/python/perf-kernels/hbm-bw-test.py @@ -0,0 +1,200 @@ +""" +Simple test to measure achieved HBM bandwidth. +This kernel moves N bytes of data from one region in HBM to another, using Triton. +""" + +# %% +# Compute Kernel +# -------------- + +import argparse +import sys +import torch + +import triton +import triton.language as tl + + +@triton.jit +def copy_kernel( + input_ptr, # *Pointer* to input vector. + output_ptr, # *Pointer* to output vector. + NUM_ELEMENTS: tl.constexpr, # Total elements to move. + BLOCK_SIZE: tl.constexpr, # Elements to load / store per iteration + VECTOR_SIZE: tl.constexpr, # Size of the entire vector being moved. + READ_ONLY: tl.constexpr, +): + pid = tl.program_id(axis=0) + # Offset at which to start for this WG. + lo = pid * NUM_ELEMENTS + # Offset until which to read for this WG. + hi = lo + NUM_ELEMENTS + # NUM_ITERS: tl.constexpr = triton.cdiv(NUM_ELEMENTS, BLOCK_SIZE) + IRREGULAR_SIZE: tl.constexpr = NUM_ELEMENTS % BLOCK_SIZE + acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if IRREGULAR_SIZE: + hi = hi - IRREGULAR_SIZE + # Move buffer in chunks of block_size + for idx in range(lo, hi, BLOCK_SIZE): + offsets = idx + tl.arange(0, BLOCK_SIZE) + in_vals = tl.load(input_ptr + offsets) + acc += in_vals + if not READ_ONLY: + tl.store(output_ptr + offsets, in_vals) + # Unroll last irregular iter in case the total sized moved by this WG + # is not a multiple of block size. + if IRREGULAR_SIZE: + lo = hi + hi = hi + IRREGULAR_SIZE + offsets = lo + tl.arange(0, BLOCK_SIZE) + mask = offsets < hi + in_vals = tl.load(input_ptr + offsets, mask=mask) + if not READ_ONLY: + tl.store(output_ptr + offsets, in_vals, mask=mask) + + if READ_ONLY: + tl.store(output_ptr + tl.arange(0, BLOCK_SIZE), acc) + + +def copy(src: torch.Tensor, block_size, wgs, dst: torch.Tensor): + assert src.is_cuda + vector_size = src.numel() + assert dst.numel() == vector_size or dst.numel() == block_size + size_per_wg = vector_size / wgs + assert size_per_wg >= block_size, \ + "Too many WGS. Please increase the size of the buffer using -size." \ + f" We want a buffer of size {wgs * block_size} f32 elements or larger." + grid = (wgs, 1, 1) + # Each WG will move these many elements + n_elements = triton.cdiv(vector_size, wgs) + # If we want to read only, we do a dummy write of a single block size back to HBM + read_only = dst.numel() != src.numel() + copy_kernel[grid]( + src, + dst, + NUM_ELEMENTS=n_elements, + BLOCK_SIZE=block_size, + VECTOR_SIZE=vector_size, + READ_ONLY=read_only, + num_warps=4, + ) + + +def get_reference(x, wgs, gbps): + ms = triton.testing.do_bench(lambda: torch.clone(x)) + bw = gbps(ms) + triton_output = torch.empty_like(x) + copy(x, block_size=16384, wgs=wgs, dst=triton_output) + err = triton_output - x + if torch.count_nonzero(err): + assert False, f"Torch and Triton do not match - max error is "\ + f"{torch.max(torch.abs(err))}" + return bw + + +def align_size_to_wgs(size, wgs): + return (size // wgs) * wgs + + +def run_benchmark_suite(vector_size, block_size, num_cores, read_only): + configs = [] + # Define WGs in powers of 2 from 1 - 2048. + x_vals = [(2**i) for i in range(0, 12)] + num_cu_aligned_wgs = [(num_cores * i) for i in range(1, 5)] + import bisect + for i in num_cu_aligned_wgs: + bisect.insort(x_vals, i) + configs.append( + triton.testing.Benchmark( + x_names=['wgs'], # Argument names to use as an x-axis for the plot. + x_vals=x_vals, x_log=True, # x axis is logarithmic. + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=['triton'], # Possible values for `line_arg`. + line_names=['Triton'], # Label name for the lines. + styles=[('blue', '-'), ('green', '-')], # Line styles. + ylabel='GiB/s', # Label name for the y-axis. + plot_name=f'size={vector_size}', # Name for the plot. Used also as a file name for saving the plot. + args={'size': vector_size}, # Values for function arguments not in `x_names` and `y_name`. + )) + + @triton.testing.perf_report(configs) + def benchmark(size, provider, wgs): + aligned_size = align_size_to_wgs(size, wgs) + src_tensor = torch.randn(aligned_size, device='cuda') + dst_tensor = torch.empty(block_size, device='cuda') + if not read_only: + dst_tensor = torch.empty_like(src_tensor) + ms = triton.testing.do_bench(lambda: copy(src_tensor, block_size, wgs, dst_tensor)) + # 8 because 4 bytes from load, 4 from store. + if read_only: + gbps = lambda ms: 4 * size / ms * 1e3 / 1024**3 + else: + gbps = lambda ms: 8 * size / ms * 1e3 / 1024**3 + return gbps(ms) + + benchmark.run(print_data=True, show_plots=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="HBM Bandwidth Benchmark", + allow_abbrev=False, + ) + parser.add_argument("-direction", type=str, default="read-only", + help="Data movement direction: read-only, read-write") + parser.add_argument("-size", type=int, default=1024, help="Size of buffer moved, in MiB") + parser.add_argument("-num_wgs", type=int, default=0, help="Number of workgroups to use") + parser.add_argument("-block_size", type=int, default=16384, help="Block size per iteration to load / store") + parser.add_argument("-run_sweep", action='store_true', default=False, help="Run sweep of B/W vs workgroups") + return parser.parse_args() + + +def main(): + args = parse_args() + torch.manual_seed(0) + num_cores = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + size = args.size + rw = args.direction == "read_write" + num_elements = size * 1024 * 1024 // 4 + if args.run_sweep: + assert args.num_wgs == 0, "If running the benchmark suite, please do not specify the number of WGs to use." + run_benchmark_suite(num_elements, args.block_size, num_cores, not rw) + return + if args.num_wgs == 0: + # num_wgs not user specified - get from device properties + num_wgs = num_cores + print(f"Using {num_wgs} workgroups. It is recommended to "\ + "use -num_wgs to provide this number.") + else: + assert args.num_wgs > 0, "Please provide a positive, non-zero number of workgroups!" + num_wgs = args.num_wgs + if num_wgs % num_cores: + print(f"Note! Your device has {num_cores} cores. It is recommended to use"\ + " a number for workgroups that is a multiple of this number."\ + f" You have currently chosen {num_wgs}.") + num_elements_rounded = align_size_to_wgs(num_elements, num_wgs) + if num_elements != num_elements_rounded: + print(f"Removing last {num_elements - num_elements_rounded} elements to "\ + "get a tensor size aligned to multiple of number of workgroups.") + num_elements = num_elements_rounded + src_tensor = torch.randn(num_elements, device="cuda") + if rw: + # 8 because 4B for read. 4B for write. + gbps = lambda ms: 8 * num_elements / ms * 1e3 / 1024**3 + ref_bw = get_reference(src_tensor, num_wgs, gbps) + print(f"Reference PyTorch bandwidth = {ref_bw} GiB/s") + else: + gbps = lambda ms: 4 * num_elements / ms * 1e3 / 1024**3 + if size < 1024: + print("Note! It is recommended to use a buffer larger than 1 GiB.") + if num_elements % args.block_size: + print("Note! This config is suboptimal. It is recommended to use a buffer that"\ + f" is a multiple of wgs x block size = {num_wgs * args.block_size} elements.") + dst_tensor = torch.empty_like(src_tensor) if rw else torch.empty(args.block_size, device='cuda') + triton_ms = triton.testing.do_bench(lambda: copy(src_tensor, args.block_size, num_wgs, dst=dst_tensor), warmup=1, + rep=1) + print(f"Triton bandwidth = {gbps(triton_ms)} GiB/s") + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py new file mode 100644 index 000000000000..beb8b0df9b1f --- /dev/null +++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py @@ -0,0 +1,485 @@ +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") + +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +# iterate, multiply and accumulate over K axis +@triton.jit() +def mac_loop( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + tile_id, + mod1, + mod2, + iters_per_tile, + start_iter, + end_iter, + pid_m, + pid_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + + # where are we in the grid + # tile_id = start_iter // iters_per_tile + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (start_iter % iters_per_tile) + # B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (start_iter % iters_per_tile) + A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (mod1) + B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (mod1) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for current_iter in range(start_iter, end_iter): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + #if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM + + +# if mod2 == 0:# last iteration of the tile always happens before its start on another SM +# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! +# tl.store(C_, acc) +# if start_iter % iters_per_tile != 0: # only if tile has been partially processed +# if mod1 != 0: # only if tile has been partially processed +# tl.atomic_xchg(locks + tile_id, 1) +# else: +# while tl.atomic_cas(locks + tile_id, 1, 1) != 1: +# pass +# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! +# tl.atomic_add(C_, acc) + if mod1 == 0 and mod2 == 0: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.atomic_add(C_, acc) + + +@triton.jit() +def first_wave( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_full_tiles_streamk, + total_partial_tiles_streamk, + iters_per_tile, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + + while start_iter < last_iter: + end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter) + mod1 = start_iter % iters_per_tile + mod2 = end_iter % iters_per_tile + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + mac_loop( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + tile_id, + mod1, + mod2, + iters_per_tile, + start_iter, + end_iter, + pid_m, + pid_n, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + ) + + start_iter = end_iter + + +# similar to the reference matmul kernel +@triton.jit() +def full_tiles( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_tiles_streamk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + # first wave has done more tiles than there are SMs, we adjust pid + tile_id = tl.program_id(0) + total_tiles_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C, acc) + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = False + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, + two_tiles: bool, num_stages: int, num_warps: int): + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + GROUP_M = 8 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + # allocates locks to sync work accross SMs + k1 = first_wave[(total_programs_streamk, )]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, + total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + if matmul._debug: + print(f"{k1.n_regs} registers used, {k1.n_spills} spills") + k2 = full_tiles[(total_blocking_tiles, )]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_tiles_streamk=total_tiles_streamk, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + if matmul._debug: + print(f"{k2.n_regs} registers used, {k2.n_spills} spills") + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, + num_stages=3, num_warps=4): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +m, n, k = 8192, 8192, 8192 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +BLK_M = 128 +BLK_N = 256 +BLK_K = 16 +two_tiles = 'True' +num_stages = 0 +num_warps = 4 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, 128, 128, 32, 4, 4) +matmul.set_debug(False) +expected = A @ B + +assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" + +# for debugging, uncomment the following line +# exit(0) + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench( + lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench( + lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench( + lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // 128) * (n // 128) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, 128, 128, 32, two_tiles, 4, 4)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py new file mode 100644 index 000000000000..a35d691a0225 --- /dev/null +++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py @@ -0,0 +1,563 @@ +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") +# global flag to indicate whether using the full tuing space +tuning_full_space = True +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +@triton.jit() +def get_tile_config(M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, two_tiles, + total_programs_streamk): + total_blocks_M = tl.cdiv(M, BLOCK_M) + total_blocks_N = tl.cdiv(N, BLOCK_N) + iters_per_tile = tl.cdiv(K, BLOCK_K) + # GROUP_M = 0 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + return iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk + + +# pruned some unreasonable config +def prune_configs(configs, named_args): + # call only for full tuning space + if not tuning_full_space: + return configs + + SIZE_M = named_args["A"].shape[0] + SIZE_N = named_args["B"].shape[1] + # SIZE_K = named_args["A"].shape[1] + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, _ =\ + kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_K"] + if SIZE_M <= 32 and BLOCK_M != 32: + continue + if SIZE_N <= 32 and BLOCK_N != 32: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def get_full_tuning_space(): + configs = [] + if not tuning_full_space: + return configs + + block_mn_range = [64, 128, 256] + block_k_range = [16, 32, 64] + num_warps_range = [1, 2, 4, 8] + # group_m_range = [0, 1, 2, 4, 8] + group_m_range = [0, 4, 8] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for num_stages in num_stage_range: + for num_waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append( + triton.Config( + { + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu, + 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack + }, + num_stages=num_stages, + num_warps=num_warps, + )) + + return configs + + +#To do: we need update the default autotune configuration once we go through the whole performance test sets. +@triton.autotune( + configs=get_full_tuning_space() if tuning_full_space else [ + triton.Config( + { + 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=0), + triton.Config( + { + 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': + 16, 'kpack': 1 + }, num_warps=4, num_stages=4), + ], + key=['M', 'N', 'K'], + # prune_configs_by={ + # 'early_config_prune': prune_configs, + # 'perf_model': None, + # "top_k": None + # }, + reset_to_zero=['C'], +) +@triton.jit() +def streamk_gemm( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, + # total_tiles_streamk, + total_programs_streamk, + two_tiles, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) + iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk = get_tile_config( + M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, total_programs_streamk) + + # Determine whether we are in the first wave or full_tiles phase based on pid + is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 + + # Calculate starting and ending iterations for first wave + if not is_first_wave: + tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + precomputed_stride_ak = BLOCK_K * stride_ak + precomputed_stride_bk = BLOCK_K * stride_bk + # pointers + A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += precomputed_stride_ak + B_BASE += precomputed_stride_bk + # acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_, acc) + else: + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder + precomputed_stride_ak = BLOCK_K * stride_ak + precomputed_stride_bk = BLOCK_K * stride_bk + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for current_iter in range(start_iter, end_iter): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += precomputed_stride_ak + B_BASE += precomputed_stride_bk + + # acc = acc.to(tl.float16) # restore C.dtype.element_ty + if remainder == 0 and end_iter % iters_per_tile == 0: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, + None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.atomic_add(C_, acc) + + start_iter = end_iter + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = True + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, + two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): + + def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_programs_streamk): + total_blocks_M = triton.cdiv(M, BLOCK_M) + total_blocks_N = triton.cdiv(N, BLOCK_N) + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + + return total_blocking_tiles + + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + # GROUP_M = 8 # 0 to disable swizzling + + if matmul._debug: + total_blocks_M = triton.cdiv(M, BLOCK_M) + total_blocks_N = triton.cdiv(N, BLOCK_N) + iters_per_tile = triton.cdiv(K, BLOCK_K) + total_tiles = total_blocks_M * total_blocks_N + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + # total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + # total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + print(f"M,N,K={M},{N},{K} ; BLOCK_M,N,K={BLOCK_M},{BLOCK_N},{BLOCK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{total_partial_tiles_streamk=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + grids = lambda META: (total_programs_streamk + compute_total_blocking_tiles(M, N, META['BLOCK_M'], META[ + 'BLOCK_N'], two_tiles, total_programs_streamk), ) + kk = streamk_gemm[(grids)]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + # total_full_tiles_streamk=total_full_tiles_streamk, + # total_partial_tiles_streamk=total_partial_tiles_streamk, + # iters_per_tile=iters_per_tile, + # total_tiles_streamk=total_tiles_streamk, + total_programs_streamk=total_programs_streamk, + two_tiles=two_tiles, + ACC_TYPE=ACC_TYPE, + # GROUP_M=GROUP_M, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # BLOCK_K=BLOCK_K, + # num_stages=num_stages, + # num_warps=num_warps, + # waves_per_eu = waves_per_eu, + ) + if matmul._debug: + print(f"{kk.n_regs} registers used, {kk.n_spills} spills") + + # print(kk.asm['ttgir']) + # print(kk.asm['amdgcn']) + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, + num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, + mfmaInstrSize=mfmaInstrSize, kpack=kpack) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +#m, n, k = 1792, 7424, 4864 # some problem size to test +#m, n, k = 8192, 8192, 8192 # some problem size to test +m, n, k = 4096, 4096, 8192 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +#A = torch.ones((m, k), device="cuda", dtype=torch.float16) +#B = torch.ones((k, n), device="cuda", dtype=torch.float16) +BLOCK_M = 256 +BLOCK_N = 256 +BLOCK_K = 64 +two_tiles = True +num_stages = 0 +num_warps = 8 +waves_per_eu = 0 +mfmaInstrSize = 16 +kpack = 1 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, + mfmaInstrSize, kpack) +matmul.set_debug(False) +expected = A @ B + +assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" +print("pass validation test") + +# for debugging, uncomment the following line +#exit(0) + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, + num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") +print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, + num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") +print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") +print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // BLOCK_M) * (n // BLOCK_N) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, + num_stages, num_warps, waves_per_eu)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + Best_tuning_config = f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})' + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + "Best tuning config": Best_tuning_config, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py new file mode 100644 index 000000000000..2651ad59d923 --- /dev/null +++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py @@ -0,0 +1,387 @@ +## matmul stream-k implementation +## Credit goes to @pommedeterresautee +## See https://github.com/openai/triton/issues/1393 + +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 +from typing import Optional + +import torch +import triton +import triton.language as tl +import random + +#from triton.runtime.driver import CudaUtils +import json + +torch.manual_seed(123) +random.seed(123) + +#device = torch.cuda.current_device() +#cuda_utils = CudaUtils() +#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] +#total_sm = 110 # for MI250 +total_sm = 304 # for MI300X +print(f"total SMs: {total_sm}") + +# --------------------------------------------------------------------------- +# Triton kernels +# --------------------------------------------------------------------------- + + +@triton.jit() +def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +@triton.jit() +def first_wave( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_full_tiles_streamk, + total_partial_tiles_streamk, + iters_per_tile, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for current_iter in range(start_iter, end_iter): + a = tl.load(A_BASE) + b = tl.load(B_BASE) + acc += tl.dot(a, b) + A_BASE += BLOCK_K * stride_ak + B_BASE += BLOCK_K * stride_bk + + if remainder == 0 and end_iter % iters_per_tile == 0: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + else: + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! + tl.atomic_add(C_, acc) + + start_iter = end_iter + + +# similar to the reference matmul kernel +@triton.jit() +def full_tiles( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_tiles_streamk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + # first wave has done more tiles than there are SMs, we adjust pid + tile_id = tl.program_id(0) + total_tiles_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers + # rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + # rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C, acc) + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + +class matmul(torch.autograd.Function): + + _debug = True + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, + two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + GROUP_M = 4 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.zeros((M, N), device=device, dtype=a.dtype) + + k1 = first_wave[(total_programs_streamk, )]( + a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack) + if matmul._debug: + print(f"{k1.n_regs} registers used, {k1.n_spills} spills") + k2 = full_tiles[(total_blocking_tiles, )](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c.stride(0), c.stride(1), total_tiles_streamk=total_tiles_streamk, + BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, + waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack) + if matmul._debug: + print(f"{k2.n_regs} registers used, {k2.n_spills} spills") +# print(k2.asm['amdgcn']) + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, + num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, + two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, + mfmaInstrSize=mfmaInstrSize, kpack=kpack) + + +# --------------------------------------------------------------------------- +# Example and Benchmark +# --------------------------------------------------------------------------- + +perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + +#m, n, k = 4864, 4096, 8256 # some problem size to test +m, n, k = 6912, 768, 256 # some problem size to test +#m, n, k = 8192, 8192, 8192 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) +#A = torch.ones((m, k), device="cuda", dtype=torch.float16) +#B = torch.ones((k, n), device="cuda", dtype=torch.float16) +BLK_M = 64 +BLK_N = 64 +BLK_K = 64 +two_tiles = 'True' +num_stages = 0 +num_warps = 4 +waves_per_eu = 0 +mfmaInstrSize = 16 +kpack = 2 + +matmul.set_debug(True) +C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, + kpack) +#exit(0) +matmul.set_debug(False) +expected = A @ B + +assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" + +# for debugging, uncomment the following line + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, + num_warps, waves_per_eu, mfmaInstrSize, kpack)) +print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, + waves_per_eu, mfmaInstrSize, kpack)) +print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") + +exit(0) +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), + torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + output: Optional[torch.Tensor] = None + + def wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True, False]: + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // BLK_M) * (n // BLK_N) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench(lambda: wrapper_matmul( + A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("results.json", "w") as f: + json.dump(results, f, indent=4) + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) From 17575ea88e229bbdbcd476553b3bc25b3b8dab58 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 13 May 2024 14:36:34 -0400 Subject: [PATCH 02/20] skip backward (#586) --- python/perf-kernels/flash-attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 6fc861b281fa..d70a43ecd36c 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -1277,6 +1277,7 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 @pytest.mark.parametrize('use_alibi', [False, True]) def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): + pytest.skip() torch.manual_seed(20) if qseqlen_not_equal_kseqlen is not None: seqlen_q = qseqlen_not_equal_kseqlen From a3d784a869aad6801694680f424c4f36e447db98 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 16 May 2024 15:20:39 -0500 Subject: [PATCH 03/20] Change all block pointers to tensor pointers (#585) Change all block pointers to tensor pointers Block pointers are for nvidia TMAs. They are useful for regular loads as well but not well supported. Also cleaned up some code I came across along the way and updated comment at the top. --- python/perf-kernels/flash-attention.py | 246 ++++++++++++------------- 1 file changed, 119 insertions(+), 127 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index d70a43ecd36c..42e9ac310195 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -2,19 +2,21 @@ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf -Features supported: +Credits: +AMD Triton kernels team +OpenAI kernel team -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: +Currently only the forward kernel is supported, and contains these features: -1) Non power of two head dims +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias """ @@ -28,10 +30,6 @@ torch_dtype: tl.constexpr = torch.float16 -TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') -if TORCH_HAS_FP8E5: - torch_dtype: tl.constexpr = torch.float8_e5m2fnuz - class MetaData(): cu_seqlens_q = None @@ -141,16 +139,22 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): return rng_keep +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. @triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) else: - tensor = tl.load(block_ptr) + tensor = tl.load(ptrs) return tensor @@ -204,19 +208,26 @@ def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, - philox_seed, batch_philox_offset, encoded_softmax_block_ptr, block_min, block_max, offs_n_causal, - masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr): +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) if PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n @@ -238,8 +249,9 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) # While bias is added after multiplying qk with sm_scale, # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. @@ -249,10 +261,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ # Compute the global position of each token within the sequence global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) - qk += (alibi_block * 1.44269504089) # scale factor of log2(e) # softmax @@ -266,26 +276,26 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_ philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i @@ -364,7 +374,7 @@ def attn_fwd( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, + USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -375,6 +385,7 @@ def attn_fwd( off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -412,18 +423,20 @@ def attn_fwd( # If we have no blocks after adjusting for seqlen deltas, this WG is part of # the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result - tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1)) + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # We store inf to LSE, not -inf because in the bwd pass, we subtract this # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - tl.store(l_ptrs, l) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) # TODO: Should dropout and return encoded softmax be handled here too? return @@ -434,41 +447,26 @@ def attn_fwd( else: off_h_k = off_h_q - # need_padding = False n_extra_tokens = 0 if seqlen_k < BLOCK_N: - # need_padding = True n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: - # need_padding = True n_extra_tokens = seqlen_k % BLOCK_N PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. - q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - Q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) - k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - K_block_ptr = tl.make_block_ptr(base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1)) - v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - V_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0)) - if BIAS_TYPE != 0: - b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs - bias_ptr = tl.make_block_ptr( - base=bias + b_offset, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn else: - bias_ptr = None + bias_ptrs = None if USE_ALIBI: a_offset = off_z * stride_az + off_h_q * stride_ah @@ -483,14 +481,11 @@ def attn_fwd( batch_philox_offset = 0 # We can ask to return the dropout mask without actually doing any dropout. In # this case, we return an invalid pointer so indicate the mask is not valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr(base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0)) + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] else: - encoded_softmax_block_ptr = 0 + encoded_sm_ptrs = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) @@ -499,8 +494,11 @@ def attn_fwd( # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -522,14 +520,16 @@ def attn_fwd( # value because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope, + block_min, block_max, 0, 0, 0, alibi_slope, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) block_min = block_max block_max = n_blocks * BLOCK_N @@ -540,18 +540,20 @@ def attn_fwd( offs_n_causal = offs_n + (seqlen_q - seqlen_k) else: offs_n_causal = 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, - alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, + offs_n, # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -578,21 +580,20 @@ def attn_fwd( overflow_size = end_m_idx - seqlen_q if overflow_size > 0: boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) - # This is a > check because mask being 0 blocks the store. - l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) else: tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh - O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0)) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) @triton.jit @@ -941,7 +942,7 @@ def forward(ctx, q, k, v, o, metadata): encoded_softmax=encoded_softmax, alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, - BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1, + BLOCK_DMODEL=padded_d_model, USE_BIAS=False if metadata.bias is None else True, USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) @@ -1065,8 +1066,6 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) cu_seqlens_q = cu_seqlens_q.to(device="cuda") cu_seqlens_k = cu_seqlens_k.to(device="cuda") - # -1 because the last entry of cu_seqlens_q specifies the end of the last seq - # num_ctxs = len(cu_seqlens_q) - 1 # Initialize q, k, v with variable lengths total_q = cu_seqlens_q[-1].item() @@ -1114,9 +1113,6 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to else: alibi_slopes = None - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1150,11 +1146,11 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ (4, 48, 1024, 1024, 64), - (4, 24, 8192, 8192, 64), + (4, 12, 8192, 8192, 64), (2, 4, 16384, 16384, 128), (2, 16, 1020, 987, 128), (2, 16, 15498, 2, 128), - (2, 16, 7, 16219, 64), + (2, 4, 7, 16219, 64), (4, 48, 1, 1, 64), (4, 48, 1, 1, 128), (4, 48, 3, 3, 128), @@ -1164,12 +1160,12 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to (4, 4, 1024, 1024, 33), (4, 4, 65, 1019, 65), (4, 4, 128, 128, 65), - (4, 4, 113, 123, 1), + # TODO: This config fails. Disabled until triaged and fixed. + # (4, 4, 113, 123, 1), ]) -@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_bias', [True]) def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): - pytest.skip() torch.manual_seed(20) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) @@ -1185,9 +1181,6 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - if TORCH_HAS_FP8E5: - q = q.to(torch_dtype) - k = k.to(torch_dtype) o = torch.empty_like(q) # triton implementation @@ -1218,9 +1211,8 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) @pytest.mark.parametrize('causal', [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - pytest.skip() + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) ref_out = torch.empty_like(q) @@ -1413,8 +1405,8 @@ def run_benchmark(custom): args = parse_args() dtype = arg_to_torch_dtype[args.dtype] - # hk = args.hq if not args.hk else args.hk - # sk = args.sq if not args.sk else args.sk + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d mode = 'fwd' x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] @@ -1422,7 +1414,7 @@ def run_benchmark(custom): varlen = args.varlen configs = [] if custom: - x_vals_list = [(args.b, args.hq, args.hk, args.sq, args.sk)] + x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] else: if varlen: x_vals_list = varlen_benchmark_configs() From aa6685a16dde93b0c559f16f39cf0cf2994c27a9 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Mon, 20 May 2024 14:57:21 -0500 Subject: [PATCH 04/20] Add support for bshd layout (#587) Add support for layouts commonly used by users. Add option for varlen / thd layout to specify equal context lengths for all batches. Also often used by users. --- python/perf-kernels/flash-attention.py | 216 +++++++++++++------------ 1 file changed, 114 insertions(+), 102 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 42e9ac310195..d36caaf61952 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -28,8 +28,6 @@ import triton import triton.language as tl -torch_dtype: tl.constexpr = torch.float16 - class MetaData(): cu_seqlens_q = None @@ -41,6 +39,7 @@ class MetaData(): causal = False num_contexts = 0 varlen = False + layout = None dropout_p, return_encoded_softmax = 0.0, False def __init__(self, sm_scale=1.0): @@ -48,6 +47,7 @@ def __init__(self, sm_scale=1.0): def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): self.varlen = True + self.layout = 'thd' self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k # Without "varlen", there should still be one sequence. @@ -81,10 +81,10 @@ def need_dropout(self, dropout_p, return_encoded_softmax): def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, self) if self.varlen: assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape assert self.cu_seqlens_q is not None assert self.cu_seqlens_k is not None assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) @@ -95,8 +95,6 @@ def check_args(self, q, k, v, o): assert not self.return_encoded_softmax else: assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 assert self.cu_seqlens_q is None and self.cu_seqlens_k is None assert k.shape == v.shape @@ -106,6 +104,8 @@ def check_args(self, q, k, v, o): assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen @triton.jit @@ -326,60 +326,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri use_cuda_graph=True, ) @triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - stride_az, - stride_ah, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - alibi_slopes, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - USE_ALIBI: tl.constexpr, - BATCH_SIZE: tl.constexpr, -): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, + stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) @@ -876,6 +830,44 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, empty = torch.empty(128, device="cuda") +def get_shape_from_layout(q, k, metadata): + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + else: + assert False, "Got unsupported layout." + return batch, nheads_q, nheads_k, head_size + + +# TODO: This can probably optimized to have fewer lines of code. +def get_strides_from_layout(q, k, v, o, metadata): + if metadata.layout == 'thd': + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif metadata.layout == 'bhsd': + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif metadata.layout == 'bshd': + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' + return q_strides, k_strides, v_strides, o_strides + + class _attention(torch.autograd.Function): @staticmethod @@ -887,24 +879,14 @@ def forward(ctx, q, k, v, o, metadata): if o is None: o = torch.empty_like(q, dtype=v.dtype) metadata.check_args(q, k, v, o) - if metadata.varlen: - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = metadata.num_contexts - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) - k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) - v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) - o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata) # Get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. padded_d_model = max(padded_d_model, 16) grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) @@ -944,7 +926,7 @@ def forward(ctx, q, k, v, o, metadata): MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if metadata.bias is None else True, USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p - > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0]) + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid @@ -1036,30 +1018,41 @@ def backward(ctx, do, _): attention = _attention.apply -def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout): torch.manual_seed(20) # Initialize q, k, v - q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + else: + assert False, 'Got unsupported tensor layout' + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.max_seqlens_q = N_CTX_Q input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout return q, k, v, input_metadata -def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False): torch.manual_seed(20) # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs - max_seqlens_q = N_CTX_Q // Z - max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) - max_seqlens_q = torch.max(seqlens_q).item() - max_seqlens_k = torch.max(seqlens_k).item() + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) # Calculate cumulative sequence lengths cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) @@ -1099,9 +1092,10 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): ]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_alibi', [True, False]) -def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=torch.float16): +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): torch.manual_seed(20) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) if causal: input_metadata.need_causal() @@ -1118,6 +1112,11 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to # triton implementation tri_out, _ = attention(q, k, v, o, input_metadata) + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() # Replicate K and V if using MQA/GQA if HQ != HK: k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], @@ -1141,6 +1140,8 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=to p[nan_mask == 1] = 0 ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) @@ -1169,8 +1170,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor torch.manual_seed(20) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = N_CTX_Q - input_metadata.max_seqlens_k = N_CTX_K + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') if causal: input_metadata.need_causal() if use_bias: @@ -1178,9 +1178,6 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) else: bias = None - q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() o = torch.empty_like(q) # triton implementation @@ -1211,6 +1208,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) @pytest.mark.parametrize('causal', [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) @@ -1401,9 +1399,8 @@ def varlen_benchmark_configs(): return configs -def run_benchmark(custom): +def run_benchmark(custom, args): - args = parse_args() dtype = arg_to_torch_dtype[args.dtype] hk = args.hq if not args.hk else args.hk sk = args.sq if not args.sk else args.sk @@ -1411,7 +1408,7 @@ def run_benchmark(custom): mode = 'fwd' x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] causal = args.causal - varlen = args.varlen + varlen = args.layout == 'thd' configs = [] if custom: x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] @@ -1425,7 +1422,7 @@ def run_benchmark(custom): configs.append( triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], line_names=[line_names], styles=[('red', '-')], ylabel='ms', - plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}', + plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}', args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode})) @triton.testing.perf_report(configs) @@ -1447,14 +1444,15 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal flops_per_matmul = 0 if varlen: - q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, + args.equal_seqlens) for i in range(0, input_metadata.num_contexts): seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] # x2 for 2 GEMMs flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 else: - q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype) + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout) flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD if causal: input_metadata.need_causal() @@ -1479,6 +1477,15 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal bench_flash_attention.run(save_path=".", print_data=True) +def supported_layouts(): + layouts = \ + 'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]' \ + 'bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]' \ + 'thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]' \ + 'This layout is sometimes called "varlen" or "grouped" layout.' + return layouts + + def parse_args(): parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", @@ -1489,11 +1496,14 @@ def parse_args(): parser.add_argument("-hk", type=int, default=0) parser.add_argument("-sq", type=int, default=0) parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-equal_seqlens", action='store_true', default=False, + help='If specified, each context within the thd layout' \ + ' has same seqlen as sq and sk') parser.add_argument("-d", type=int, default=0) parser.add_argument("-causal", action='store_true', default=False) - parser.add_argument("-varlen", action='store_true', default=False) parser.add_argument("-dtype", default='fp16') parser.add_argument("-return_time", action='store_true', default=False) + parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) return parser.parse_args() @@ -1503,6 +1513,8 @@ def parse_args(): def main(): args = parse_args() custom_config = False + assert args.layout == 'thd' or not args.equal_seqlens, \ + "Equal sequence lengths arg must be used with the thd layout." if args.b or args.hq or args.hk or args.sq or args.sk or args.d: custom_config = True assert args.b and args.hq and args.sq and args.d, \ @@ -1513,7 +1525,7 @@ def main(): assert args.dtype in arg_to_torch_dtype, \ "Only fp16, bf16 and f32 types currently supported." - run_benchmark(custom_config) + run_benchmark(custom_config, args) if __name__ == '__main__': From dbe11738b9de976b4423db46faa94385a900ae6e Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 16 Jul 2024 19:22:27 -0400 Subject: [PATCH 05/20] Post-Merge CI (#612) * remove on push for Integration Tests * rename * add post merge test * save * dtype params * skip bad config * fix more stuff --- ... => amd_perf_kernel_Integration_tests.yml} | 8 +- .../amd_perf_kernel_postmerge_tests.yml | 92 +++++++++++++++++++ python/perf-kernels/flash-attention.py | 11 ++- 3 files changed, 101 insertions(+), 10 deletions(-) rename .github/workflows/{amd_perf_kernel_tests.yml => amd_perf_kernel_Integration_tests.yml} (95%) create mode 100644 .github/workflows/amd_perf_kernel_postmerge_tests.yml diff --git a/.github/workflows/amd_perf_kernel_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml similarity index 95% rename from .github/workflows/amd_perf_kernel_tests.yml rename to .github/workflows/amd_perf_kernel_Integration_tests.yml index 07424924a832..a8a8b3d50b9e 100644 --- a/.github/workflows/amd_perf_kernel_tests.yml +++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml @@ -1,4 +1,4 @@ -name: AMD Perf Kernel Tests +name: AMD Perf Kernel Integration Tests on: workflow_dispatch: @@ -7,8 +7,6 @@ on: merge_group: branches: [main_perf] types: [checks_requested] - push: - branches: [main_perf] concurrency: group: ${{ github.ref }} @@ -36,8 +34,8 @@ jobs: changed_files=$(git diff --name-only origin/${{ github.base_ref }} ${{ github.sha }}) echo "Changed files:" echo "$changed_files" - if echo "$changed_files" | grep -v "^python/perf-kernels/"; then - echo "Changes detected outside of the python/perf-kernels directory. Failing the workflow." + if echo "$changed_files" | grep -vE "^python/perf-kernels/|^\.github/workflows/amd_"; then + echo "Changes detected outside of the python/perf-kernels directory or .github/workflows/amd_ files. Failing the workflow." exit 1 fi diff --git a/.github/workflows/amd_perf_kernel_postmerge_tests.yml b/.github/workflows/amd_perf_kernel_postmerge_tests.yml new file mode 100644 index 000000000000..40f211118541 --- /dev/null +++ b/.github/workflows/amd_perf_kernel_postmerge_tests.yml @@ -0,0 +1,92 @@ +name: AMD Perf Kernel Post-Merge Tests + +on: + workflow_dispatch: + push: + branches: [main_perf, micmelesse/post_merge_ci] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main_perf' }} + +permissions: read-all + +env: + TRITON_BUILD_WITH_CLANG_LLD: "TRUE" + TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" + TRITON_DISABLE_LINE_INFO: 1 + +jobs: + Runner-Preparation-AMD: + runs-on: ubuntu-latest + timeout-minutes: 30 + outputs: + matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} + steps: + - name: Prepare runner matrix + id: set-matrix + run: | + if [ x"${{ github.repository }}" == x"ROCm/triton" ]; then + echo '::set-output name=matrix-HIP::[["self-hosted", "rocm.gfx90a"]]' + else + echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]' + fi + + PostMerge-Tests-AMD: + needs: Runner-Preparation-AMD + if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' + runs-on: ${{ matrix.runner }} + timeout-minutes: 30 + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} + container: + image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Ensure the entire history is fetched for rebase + - name: Add upstream remote + run: | + git config --global --add safe.directory /__w/triton/triton + if [ $(git remote | grep -c upstream) -eq 0 ]; then + git remote add upstream https://github.com/triton-lang/triton.git + fi + git fetch upstream + - name: Rebase onto upstream/main + run: | + git config --global user.email "ci@amd.com" + git config --global user.name "Github Actions Post-Merge CI Script" + git rebase upstream/main || { echo "Rebase failed"; exit 1; } + - name: Show Git Log + run: | + echo "Git log after rebase from upstream/main to HEAD:" + git log $(git rev-parse upstream/main~2)..HEAD --oneline --graph --decorate + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Clear cache + run: | + rm -rf ~/.triton + mkdir -p ~/.triton + ls -alh ~/.triton + - name: Update PATH + run: | + echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH + - name: Install pip dependencies + run: | + python3 -m pip install --upgrade pip + python3 -m pip install lit matplotlib pandas + - name: Install Triton + run: | + echo "PATH is '$PATH'" + pip uninstall -y triton + cd python + pip install -v -e . + - name: Run Perf Kernels Unit Tests + run: | + pytest -vvv ./python/perf-kernels/flash-attention.py + - name: Run Perf Kernels Benchmark + run: | + python ./python/perf-kernels/flash-attention.py diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index d36caaf61952..8177cf4ebf30 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -309,8 +309,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, @@ -1166,7 +1166,8 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, ]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('use_bias', [True]) -def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): torch.manual_seed(20) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) @@ -1174,7 +1175,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor if causal: input_metadata.need_causal() if use_bias: - bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda") input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) else: bias = None @@ -1197,7 +1198,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor # this by converting the NaNs to 0s, which is what they should be out of the softmax. nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) From 23ba5467d83db1d8ca36f8ee34ff287ae089469c Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Thu, 18 Jul 2024 17:04:16 -0500 Subject: [PATCH 06/20] Increase CI timeout (#615) Increase CI timeout --- .github/workflows/amd_perf_kernel_Integration_tests.yml | 2 +- .github/workflows/amd_perf_kernel_postmerge_tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/amd_perf_kernel_Integration_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml index a8a8b3d50b9e..956ff8903115 100644 --- a/.github/workflows/amd_perf_kernel_Integration_tests.yml +++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml @@ -95,7 +95,7 @@ jobs: needs: Runner-Preparation-AMD if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 90 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} diff --git a/.github/workflows/amd_perf_kernel_postmerge_tests.yml b/.github/workflows/amd_perf_kernel_postmerge_tests.yml index 40f211118541..21470c094e46 100644 --- a/.github/workflows/amd_perf_kernel_postmerge_tests.yml +++ b/.github/workflows/amd_perf_kernel_postmerge_tests.yml @@ -36,7 +36,7 @@ jobs: needs: Runner-Preparation-AMD if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 90 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} From df4c4d3a7fa7a1329626972b36b9b5d8a84c75f2 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Fri, 19 Jul 2024 17:50:49 -0500 Subject: [PATCH 07/20] Couple of FA optimizations (#608) Couple of FA optimizations Set SM scale multiplication to a constexpr. Minor asm improvement. Changed acc scaling to adjust for softmax division to multiplication with reciprocal. ~10% perf improvement. --------- Co-authored-by: Michael Melesse --- python/perf-kernels/flash-attention.py | 39 ++++++++++++-------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 8177cf4ebf30..988438340abe 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -301,35 +301,28 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - # num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=8), - # TODO: This config fails with head_size not pow2 with data mismatches. Check why. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'], use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, - stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, - stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, +def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, + stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, + stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, + cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, @@ -446,13 +439,13 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # scale sm_scale by log_2(e) and use 2^x in the loop as we do not # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 + QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - q = (q * qk_scale).to(q.type.element_ty) + q = (q * QK_SCALE).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -509,7 +502,10 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL) # epilogue - acc = acc / l_i[:, None] + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, @@ -1198,6 +1194,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): # this by converting the NaNs to 0s, which is what they should be out of the softmax. nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) From 52a908fd512e1eb0790d2c48eae2b70e982e751d Mon Sep 17 00:00:00 2001 From: xiaohuguo2023 <149615094+xiaohuguo2023@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:33:02 +0100 Subject: [PATCH 08/20] streamk v0.1 (#619) * streamk v0.1 * remove unused variable * fix format issues * add README * fix format issue * change num_sms to num_cus --- .../03-matrix-multiplication-stream-k.py | 395 -------- ...trix-multiplication-stream-k-oldversion.py | 485 ---------- ...iplication-stream-k-singlekern-autotune.py | 563 ------------ ...ultiplication-stream-k-singleloop-nomod.py | 387 -------- python/perf-kernels/streamk/README.md | 43 + python/perf-kernels/streamk/streamk_kernel.py | 206 +++++ python/perf-kernels/streamk/tune_streamk.py | 847 ++++++++++++++++++ 7 files changed, 1096 insertions(+), 1830 deletions(-) delete mode 100755 python/perf-kernels/03-matrix-multiplication-stream-k.py delete mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py delete mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py delete mode 100644 python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py create mode 100644 python/perf-kernels/streamk/README.md create mode 100644 python/perf-kernels/streamk/streamk_kernel.py create mode 100644 python/perf-kernels/streamk/tune_streamk.py diff --git a/python/perf-kernels/03-matrix-multiplication-stream-k.py b/python/perf-kernels/03-matrix-multiplication-stream-k.py deleted file mode 100755 index 62d820719b9a..000000000000 --- a/python/perf-kernels/03-matrix-multiplication-stream-k.py +++ /dev/null @@ -1,395 +0,0 @@ -#!/usr/bin/env python -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -@triton.jit() -def streamk_gemm( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - total_tiles_streamk, - total_programs_streamk, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(0) - - # Determine whether we are in the first wave or full_tiles phase based on pid - is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 - - # Calculate starting and ending iterations for first wave - if not is_first_wave: - tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers -# rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) -# rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C_, acc) - else: - # start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 4 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{total_full_tiles_streamk=}") - print(f"{total_partial_tiles_streamk=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - # allocates locks to sync work accross SMs - grids = total_programs_streamk + total_blocking_tiles - kk = streamk_gemm[(grids, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, - total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, - total_tiles_streamk=total_tiles_streamk, - total_programs_streamk=total_programs_streamk, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - num_stages=num_stages, - num_warps=num_warps, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, - ) - if matmul._debug: - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - - # print(kk.asm['ttgir']) - # print(kk.asm['amdgcn']) - - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 4864, 4096, 8256 # some problem size to test -#m, n, k = 4096, 4096, 8192 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -m, n, k = 6912, 768, 256 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -BLK_M = 64 -BLK_N = 64 -BLK_K = 64 -two_tiles = 'True' -num_stages = 0 -num_warps = 4 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 2 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, - kpack) -#exit(0) -matmul.set_debug(False) -expected = A @ B - -#assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" -print("pass validation test") - -# for debugging, uncomment the following line -# exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, - waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLK_M) * (n // BLK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench( - lambda: wrapper_matmul(A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py deleted file mode 100644 index beb8b0df9b1f..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py +++ /dev/null @@ -1,485 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -# iterate, multiply and accumulate over K axis -@triton.jit() -def mac_loop( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - tile_id, - mod1, - mod2, - iters_per_tile, - start_iter, - end_iter, - pid_m, - pid_n, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, -): - - # where are we in the grid - # tile_id = start_iter // iters_per_tile - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (start_iter % iters_per_tile) - # B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (start_iter % iters_per_tile) - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (mod1) - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (mod1) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for current_iter in range(start_iter, end_iter): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - #if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM - - -# if mod2 == 0:# last iteration of the tile always happens before its start on another SM -# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! -# tl.store(C_, acc) -# if start_iter % iters_per_tile != 0: # only if tile has been partially processed -# if mod1 != 0: # only if tile has been partially processed -# tl.atomic_xchg(locks + tile_id, 1) -# else: -# while tl.atomic_cas(locks + tile_id, 1, 1) != 1: -# pass -# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! -# tl.atomic_add(C_, acc) - if mod1 == 0 and mod2 == 0: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.atomic_add(C_, acc) - - -@triton.jit() -def first_wave( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - pid = tl.program_id(0) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - - while start_iter < last_iter: - end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter) - mod1 = start_iter % iters_per_tile - mod2 = end_iter % iters_per_tile - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - mac_loop( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - tile_id, - mod1, - mod2, - iters_per_tile, - start_iter, - end_iter, - pid_m, - pid_n, - BLOCK_M, - BLOCK_N, - BLOCK_K, - ACC_TYPE, - ) - - start_iter = end_iter - - -# similar to the reference matmul kernel -@triton.jit() -def full_tiles( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_tiles_streamk, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - # first wave has done more tiles than there are SMs, we adjust pid - tile_id = tl.program_id(0) + total_tiles_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C, acc) - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = False - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 8 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - # allocates locks to sync work accross SMs - k1 = first_wave[(total_programs_streamk, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, - total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - num_stages=num_stages, - num_warps=num_warps, - ) - if matmul._debug: - print(f"{k1.n_regs} registers used, {k1.n_spills} spills") - k2 = full_tiles[(total_blocking_tiles, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_tiles_streamk=total_tiles_streamk, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - num_stages=num_stages, - num_warps=num_warps, - ) - if matmul._debug: - print(f"{k2.n_regs} registers used, {k2.n_spills} spills") - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -m, n, k = 8192, 8192, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -BLK_M = 128 -BLK_N = 256 -BLK_K = 16 -two_tiles = 'True' -num_stages = 0 -num_warps = 4 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, 128, 128, 32, 4, 4) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" - -# for debugging, uncomment the following line -# exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // 128) * (n // 128) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, 128, 128, 32, two_tiles, 4, 4)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py deleted file mode 100644 index a35d691a0225..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py +++ /dev/null @@ -1,563 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") -# global flag to indicate whether using the full tuing space -tuning_full_space = True -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -@triton.jit() -def get_tile_config(M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, two_tiles, - total_programs_streamk): - total_blocks_M = tl.cdiv(M, BLOCK_M) - total_blocks_N = tl.cdiv(N, BLOCK_N) - iters_per_tile = tl.cdiv(K, BLOCK_K) - # GROUP_M = 0 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - return iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk - - -# pruned some unreasonable config -def prune_configs(configs, named_args): - # call only for full tuning space - if not tuning_full_space: - return configs - - SIZE_M = named_args["A"].shape[0] - SIZE_N = named_args["B"].shape[1] - # SIZE_K = named_args["A"].shape[1] - - pruned_configs = [] - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, _ =\ - kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_K"] - if SIZE_M <= 32 and BLOCK_M != 32: - continue - if SIZE_N <= 32 and BLOCK_N != 32: - continue - - pruned_configs.append(config) - - return pruned_configs - - -def get_full_tuning_space(): - configs = [] - if not tuning_full_space: - return configs - - block_mn_range = [64, 128, 256] - block_k_range = [16, 32, 64] - num_warps_range = [1, 2, 4, 8] - # group_m_range = [0, 1, 2, 4, 8] - group_m_range = [0, 4, 8] - # For now we see better perf with num_stages=0 for all gemm configs we care - # But keep this explicit so that we do not forget we may need to set it to - # other values in the future - num_stage_range = [0] - waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] - - for block_m in block_mn_range: - for block_n in block_mn_range: - for block_k in block_k_range: - for num_warps in num_warps_range: - for group_m in group_m_range: - for num_stages in num_stage_range: - for num_waves_per_eu in waves_per_eu_range: - for matrix_instr_nonkdim in matrix_instr_nonkdim_range: - for kpack in kpack_range: - configs.append( - triton.Config( - { - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu, - 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack - }, - num_stages=num_stages, - num_warps=num_warps, - )) - - return configs - - -#To do: we need update the default autotune configuration once we go through the whole performance test sets. -@triton.autotune( - configs=get_full_tuning_space() if tuning_full_space else [ - triton.Config( - { - 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=4), - ], - key=['M', 'N', 'K'], - # prune_configs_by={ - # 'early_config_prune': prune_configs, - # 'perf_model': None, - # "top_k": None - # }, - reset_to_zero=['C'], -) -@triton.jit() -def streamk_gemm( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, - # total_tiles_streamk, - total_programs_streamk, - two_tiles, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(0) - iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk = get_tile_config( - M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, total_programs_streamk) - - # Determine whether we are in the first wave or full_tiles phase based on pid - is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 - - # Calculate starting and ending iterations for first wave - if not is_first_wave: - tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - precomputed_stride_ak = BLOCK_K * stride_ak - precomputed_stride_bk = BLOCK_K * stride_bk - # pointers - A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += precomputed_stride_ak - B_BASE += precomputed_stride_bk - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C_, acc) - else: - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - precomputed_stride_ak = BLOCK_K * stride_ak - precomputed_stride_bk = BLOCK_K * stride_bk - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += precomputed_stride_ak - B_BASE += precomputed_stride_bk - - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - - def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_programs_streamk): - total_blocks_M = triton.cdiv(M, BLOCK_M) - total_blocks_N = triton.cdiv(N, BLOCK_N) - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - - return total_blocking_tiles - - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - # GROUP_M = 8 # 0 to disable swizzling - - if matmul._debug: - total_blocks_M = triton.cdiv(M, BLOCK_M) - total_blocks_N = triton.cdiv(N, BLOCK_N) - iters_per_tile = triton.cdiv(K, BLOCK_K) - total_tiles = total_blocks_M * total_blocks_N - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - # total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - # total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - print(f"M,N,K={M},{N},{K} ; BLOCK_M,N,K={BLOCK_M},{BLOCK_N},{BLOCK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{total_partial_tiles_streamk=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - grids = lambda META: (total_programs_streamk + compute_total_blocking_tiles(M, N, META['BLOCK_M'], META[ - 'BLOCK_N'], two_tiles, total_programs_streamk), ) - kk = streamk_gemm[(grids)]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - # total_full_tiles_streamk=total_full_tiles_streamk, - # total_partial_tiles_streamk=total_partial_tiles_streamk, - # iters_per_tile=iters_per_tile, - # total_tiles_streamk=total_tiles_streamk, - total_programs_streamk=total_programs_streamk, - two_tiles=two_tiles, - ACC_TYPE=ACC_TYPE, - # GROUP_M=GROUP_M, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # BLOCK_K=BLOCK_K, - # num_stages=num_stages, - # num_warps=num_warps, - # waves_per_eu = waves_per_eu, - ) - if matmul._debug: - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - - # print(kk.asm['ttgir']) - # print(kk.asm['amdgcn']) - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 1792, 7424, 4864 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -m, n, k = 4096, 4096, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -#A = torch.ones((m, k), device="cuda", dtype=torch.float16) -#B = torch.ones((k, n), device="cuda", dtype=torch.float16) -BLOCK_M = 256 -BLOCK_N = 256 -BLOCK_K = 64 -two_tiles = True -num_stages = 0 -num_warps = 8 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 1 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, - mfmaInstrSize, kpack) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" -print("pass validation test") - -# for debugging, uncomment the following line -#exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, - num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, - num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLOCK_M) * (n // BLOCK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, - num_stages, num_warps, waves_per_eu)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - Best_tuning_config = f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})' - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - "Best tuning config": Best_tuning_config, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py deleted file mode 100644 index 2651ad59d923..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py +++ /dev/null @@ -1,387 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -@triton.jit() -def first_wave( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - pid = tl.program_id(0) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# similar to the reference matmul kernel -@triton.jit() -def full_tiles( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_tiles_streamk, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - # first wave has done more tiles than there are SMs, we adjust pid - tile_id = tl.program_id(0) + total_tiles_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - # rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - # rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C, acc) - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 4 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - - k1 = first_wave[(total_programs_streamk, )]( - a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack) - if matmul._debug: - print(f"{k1.n_regs} registers used, {k1.n_spills} spills") - k2 = full_tiles[(total_blocking_tiles, )](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c.stride(0), c.stride(1), total_tiles_streamk=total_tiles_streamk, - BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, - waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack) - if matmul._debug: - print(f"{k2.n_regs} registers used, {k2.n_spills} spills") -# print(k2.asm['amdgcn']) - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 4864, 4096, 8256 # some problem size to test -m, n, k = 6912, 768, 256 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -#A = torch.ones((m, k), device="cuda", dtype=torch.float16) -#B = torch.ones((k, n), device="cuda", dtype=torch.float16) -BLK_M = 64 -BLK_N = 64 -BLK_K = 64 -two_tiles = 'True' -num_stages = 0 -num_warps = 4 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 2 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, - kpack) -#exit(0) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" - -# for debugging, uncomment the following line - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, - waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLK_M) * (n // BLK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul( - A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/README.md b/python/perf-kernels/streamk/README.md new file mode 100644 index 000000000000..aa0b11d41b73 --- /dev/null +++ b/python/perf-kernels/streamk/README.md @@ -0,0 +1,43 @@ +# streamk gemm script v0.1 + +The plan is to use this version as the base version for the future triton streamk gemm development. + +### Main features +- comparable performance with tune gemm + +- use the persistent loop so that a WG may work on multiple output tiles, and also allowing workgroups to do part of the work for an output tile. + +- use atomics for spinning lock to replace atomic_add for the final output. + +- pid renumbering based on chiplet structure of MI300X + +- dynamic grid setting + +- tuning script adapt from tune_gemm + +### Usage + +Go to the script dir +```bash +cd triton/python/perf_kernels/streamk +``` + +1. Tune gemm sizes given in a yaml file and check correctness on the way +```bash +python tune_streamk.py --gemm_size_file input_gemm_sizes.yaml --compare +``` + +2. Tune a single gemm size +```bash +python tune_streamk.py -m 16 -n 16 -k 16 +``` + +3. Choose the file to store tuning results +```bash +python tune_streamk.py --gemm_size_file input_gemm_sizes.yaml --o output_tuning.yaml +``` + +4. Only check correctness given the tuning results +```bash +python tune_streamk.py --gemm_size_file output_tuning.yaml --compare_wo_tuning +``` diff --git a/python/perf-kernels/streamk/streamk_kernel.py b/python/perf-kernels/streamk/streamk_kernel.py new file mode 100644 index 000000000000..138e6540e203 --- /dev/null +++ b/python/perf-kernels/streamk/streamk_kernel.py @@ -0,0 +1,206 @@ +import triton +import triton.language as tl + + +@triton.jit() +def get_new_pid(current_pid, num_cus): + # Number of XCDs + num_xcds = 8 + # Number of pids per XCD in the new arrangement + pids_per_xcd = num_cus // num_xcds + # Compute current XCD and local pid within the XCD + xcd = current_pid % num_xcds + local_pid = current_pid // num_xcds + + # Calculate new pid based on the new grouping + new_pid = xcd * pids_per_xcd + local_pid + return new_pid + + +@triton.jit() +def get_tiles_config( + M, + N, + K, + num_cus, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + total_blocks_M = tl.cdiv(M, BLOCK_SIZE_M) + total_blocks_N = tl.cdiv(N, BLOCK_SIZE_N) + iters_per_tile = tl.cdiv(K, BLOCK_SIZE_K) + + total_tiles = total_blocks_M * total_blocks_N + if num_cus > 0 and total_tiles > num_cus: # Stream-K + total_streamk_tiles = total_tiles % num_cus + total_full_tiles = total_tiles - total_streamk_tiles + total_streamk_iters = total_streamk_tiles * iters_per_tile + # iterations related to full waves + streamk_iters_pcu = total_streamk_iters // num_cus + # iterations related to last (partial) wave + streamk_remainder_iters = total_streamk_iters % num_cus + + else: # all tiles are computed using classical blocking + total_full_tiles = total_tiles + total_streamk_tiles = 0 + streamk_iters_pcu = 0 + streamk_remainder_iters = 0 + total_streamk_iters = 0 + + return iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters + + +@triton.jit() +def streamk_gemm( + A, + B, + C, + P, + locks, + M, + N, + K, + num_cus, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + pid = get_new_pid(pid, num_cus) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters = get_tiles_config( + M, N, K, num_cus, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + rk = tl.arange(0, BLOCK_SIZE_K) + + for tile_id in range(pid, total_full_tiles, num_cus): + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if EVEN_K: + a = tl.load(A_BASE) + b = tl.load(B_BASE) + else: + a = tl.load(A_BASE, mask=rk[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + c = acc.to(C.type.element_ty) + + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_, c, mask=mask) + + start_iter = total_full_tiles * iters_per_tile + pid * streamk_iters_pcu + tl.minimum(pid, streamk_remainder_iters) + last_iter = total_full_tiles * iters_per_tile + (pid + 1) * streamk_iters_pcu + tl.minimum( + pid + 1, streamk_remainder_iters) + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + # rk = tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for current_iter in range(start_iter, end_iter): + if EVEN_K: + a = tl.load(A_BASE) + b = tl.load(B_BASE) + else: + global_k_offset = (current_iter % iters_per_tile) * BLOCK_SIZE_K + k_mask = global_k_offset + rk < K + a = tl.load(A_BASE, mask=k_mask[None, :], other=0.0) + b = tl.load(B_BASE, mask=k_mask[:, None], other=0.0) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + tile_iter = tile_id * iters_per_tile + if start_iter == tile_iter: + tile_iter_end = tile_iter + iters_per_tile + next_pid = pid + 1 + end = end_iter + while (end < tile_iter_end and next_pid < num_cus): + # todo: try use tl.load once cache modifier landed upstream + while tl.atomic_cas(locks + next_pid, 1, 1) != 1: + pass + rm1 = tl.arange(0, BLOCK_SIZE_M) + rn1 = tl.arange(0, BLOCK_SIZE_N) + rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N) + P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :] + acc += tl.load(P_) + end += streamk_iters_pcu + (next_pid < streamk_remainder_iters) + + next_pid += 1 + + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_, c, mask=mask) + + else: + rm1 = tl.arange(0, BLOCK_SIZE_M) + rn1 = tl.arange(0, BLOCK_SIZE_N) + rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N) + P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :] + tl.store(P_, acc) + tl.atomic_xchg(locks + pid, 1) + + start_iter = end_iter diff --git a/python/perf-kernels/streamk/tune_streamk.py b/python/perf-kernels/streamk/tune_streamk.py new file mode 100644 index 000000000000..3b0fbdb960c7 --- /dev/null +++ b/python/perf-kernels/streamk/tune_streamk.py @@ -0,0 +1,847 @@ +# fp8 +import argparse +import sys +import yaml +import os +import glob +import subprocess + +import torch +import triton +import triton.language as tl + +from streamk_kernel import streamk_gemm + +from datetime import datetime +import multiprocessing +import pandas as pd + +device_oi = 650. / 3.0 + + +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + 'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, + 'GROUP_SIZE_M': group_m, 'num_warps': num_warps, 'num_stages': num_stages, + 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, + 'kpack': kpack + }) + + return configs + + +def get_gemm_oi(M, N, K): + FLOPs = 2 * M * N * K + # 4 for fp32 + # to do check dtype for bytesmoved + bytesmoved = (M * K + K * N + 2 * M * N) * 4 + return FLOPs / bytesmoved + + +def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): + pruned_configs = [] + + if M < 32 or N < 32: + mfma = 16 + else: + mfma = 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + kpack = config.get("kpack") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elemens per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + GROUP_M = config.get("GROUP_SIZE_M") + if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim: + continue + if BLOCK_SIZE_K == 16 and matrix_instr_nonkdim == 16 and kpack == 2: + continue + if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim: + continue + if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16: + continue + if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def run_bash_command_wrapper(commandstring, capture=True): + try: + run_bash_command(commandstring, capture) + except subprocess.CalledProcessError: + if not capture: + print(f"running {commandstring} one more time") + run_bash_command(commandstring, capture) + + +def run_bash_command(commandstring, capture=True): + if capture: + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) + return proc.stdout.splitlines() + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash') + return None + + +def read_config(config): + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + block_k = config.get('BLOCK_SIZE_K') + group_m = config.get('GROUP_SIZE_M') + num_warps = config.get('num_warps') + num_stages = config.get('num_stages') + waves_per_eu = config.get('waves_per_eu') + mfma_instr_size = config.get('matrix_instr_nonkdim') + kpack = config.get('kpack') + return block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack + + +def gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a, dtype_b, dtype_c, dtype_p, + dtype_lock): + block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) + torch_dtype_a = 'fp16' + torch_dtype_b = 'fp16' + torch_dtype_c = 'fp16' + torch_dtype_p = 'fp32' + torch_dtype_lock = 'int32' + if dtype_a: + torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] + if dtype_b: + torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] + if dtype_c: + torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] + if dtype_p: + torch_dtype_p = tl_to_torch_types[name_to_tl_types[dtype_p]] + if dtype_lock: + torch_dtype_lock = tl_to_torch_types[name_to_tl_types[dtype_lock]] + configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" + + matmul_def_str = f""" +def matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, am, ak, bk, bn, cm, cn, warmup=False): + grid = num_cus + #print(f'config: streamk_gemm_{configStr}', flush=True) + if warmup: + streamk_gemm_{configStr}.warmup( + {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_p}, {torch_dtype_lock}, + M, N, K, num_cus, + am, ak, bk, bn, cm, cn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + EVEN_K = {EVEN_K}, + grid=(1,) + ) + return None + else: + streamk_gemm_{configStr}[grid,]( + a, b, c, P, locks, + M, N, K, num_cus, + am, ak, bk, bn, cm, cn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + EVEN_K = {EVEN_K} + ) + return c + +def try_config_{configStr}(M, N, K, num_cus, am, ak, bk, bn, cm, cn): + try: + matmul_{configStr}(None, None, None, None, None, M, N, K, num_cus, am, ak, bk, bn, cm, cn, True) + return True + except Exception as e: + print(f'invalid config(compilation): {configStr}: ', e, flush=True) + return False +""" + return configStr, matmul_def_str + + +def generated_kernel_name(M, N, K, gpu_id): + return f"generated_kernel{M}-{N}-{K}-{gpu_id}.py" + + +# Open {len(gpus)} files +# generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py +# and generate +# 1. matmul kernels of all configs +# 2. wrapper function matmul to invoke all the generated kernels +# 3. Another wraper function try_config to invoke matmul function +# 4. test_gemm to invoke +# 4.1 run try_config in parallel +# 4.2 matmul in a loop of 10 iterations +def generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + jobs, iters, run_bench): + filenames = [] + for i in range(jobs): + filenames.append(generated_kernel_name(M, N, K, i)) + f_kernel = [open(path, 'w') for path in filenames] + + # write imports + import_str = """import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_streamk import gen_input +""" + for fi in range(jobs): + f_kernel[fi].write(import_str + "\n") + + # write definitions of streamk_gemm_xxx + # and matmul_xxx and try_config + with open("streamk_kernel.py") as file: + streamk_gemm_code = file.read() + idx = 0 + for config in configs: + file_idx = idx % jobs + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a, + dtype_b, dtype_c, dtype_p, dtype_lock) + # Copy the streamk_gemm with name replaced + streamk_gemm_config = streamk_gemm_code.replace("streamk_gemm", f"streamk_gemm_{configStr}") + streamk_gemm_config = streamk_gemm_config.replace("import triton.language as tl", "") + streamk_gemm_config = streamk_gemm_config.replace("import triton", "") + f_kernel[file_idx].write(streamk_gemm_config + "\n\n") + f_kernel[file_idx].write(matmul_def_str + "\n") + idx += 1 + + # write test_gemm + # pre string + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + test_gemm_pre_str = f"""def test_gemm(M, N, K, num_cus, num_threads): + thread_pool = multiprocessing.Pool(processes=num_threads) + a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda') + b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda') + c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]}) + task_args = (M, N, K, num_cus, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1)) + + if num_threads > 1: + results = [] + config_names = [] +""" + for fi in range(jobs): + f_kernel[fi].write(test_gemm_pre_str + "\n") + + # warm up call of all matmul functions in parallel + idx = 0 + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, + None) + task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \ + f" config_names += ['{configStr}']\n" + f_kernel[idx % jobs].write(task_str) + idx += 1 + + for fi in range(jobs): + threadpool_str = """ + failed_configs = [] + for i in range(len(results)): + results[i].wait() + res = results[i].get() + if not res: + failed_configs += [config_names[i]] + thread_pool.close() + thread_pool.join() + with open("{filename}.failed_configs", "w") as f: + for cfg in failed_configs: + f.write(cfg + "\\n") + else: + try: + with open("{filename}.failed_configs", "r") as f: + failed_configs = [cfg.strip() for cfg in f.readlines()] + except Exception: + failed_configs = [] + """.format(filename=filenames[fi]) + f_kernel[fi].write(threadpool_str) + # call all matmul_xxx functions + idx = 0 + runs = iters if run_bench else 200 + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, + None) + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + matmul_call_str = f""" + if '{configStr}' not in failed_configs: + print(f"{configStr}") + for i in range({runs}): + locks = torch.zeros((num_cus,), device = "cuda", dtype = torch.int32) + P = torch.zeros((num_cus, {block_m}*{block_n}), device="cuda", dtype=torch.float32) + d = matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))""" + f_kernel[idx % jobs].write(matmul_call_str + "\n") + idx += 1 + # post string + for fi in range(jobs): + f_kernel[fi].write(" return d\n") + + # def main and call test_gemm + def_main_str = """ +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=1, help='number of threads') + args = parser.parse_args() + numThreads = args.n + num_cus = 304 + """ + test_gemm_call_str = f'test_gemm({M}, {N}, {K}, num_cus, numThreads)' + for fi in range(jobs): + f_kernel[fi].write(def_main_str) + f_kernel[fi].write(test_gemm_call_str + "\n\n") + f_kernel[fi].write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel[fi].close() + + +def extract_kernel_time(M, N, K, num_cus, EVEN_K, config, df): + # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19 + # once the bug is fixed, we should not need below two lines + cols = [ + 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr', + 'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' + ] + df.columns = cols + + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, None) + + filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy() + filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs'] + meanTime = filtered_df['DurationNs'].tail(100).mean() + return config, meanTime + + +def profile_batch_kernels(M, N, K, num_cus, gpuid, gpus, jobs, verbose): + ngpus = len(gpus) + gpuIdx = gpus.index(gpuid) + if gpuIdx + 1 > jobs: + return + os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid) + jobId = gpuIdx + while jobId < jobs: + if verbose: + print(f"profiling {generated_kernel_name(M, N, K, jobId)} on GPU {gpuid}") + run_bash_command_wrapper( + f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {generated_kernel_name(M, N, K, jobId)}", + capture=(verbose < 2)) + jobId += ngpus + + +def tune_gemm_config(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + run_bench, jobs, iters, skipWarmup, verbose=0, num_threads=16, gpus=[0]): + # Generate kernel out of all configs + generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + jobs, iters, run_bench) + + # remove any compiled kernel in the cache + run_bash_command("rm -rf ~/.triton/cache") + + # precompile the kernels in parallel + start_time = datetime.now() + if not skipWarmup: + for i in range(jobs): + run_bash_command(f"python {generated_kernel_name(M, N, K, i)} -n {num_threads}", capture=(verbose < 2)) + compile_end = datetime.now() + compile_time = compile_end - start_time + if verbose: + print(f"compile time: {compile_time}", flush=True) + + # profile generated kernels + running = [ + multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, num_cus, gpu_id, gpus, jobs, verbose)) + for gpu_id in gpus + ] + for p in running: + p.start() + for p in running: + p.join() + + profile_end = datetime.now() + profile_time = profile_end - compile_end + if verbose: + print(f"profile time: {profile_time}", flush=True) + + # post process results.csv to get the best config and minTime + # TODO: process the file in parallel + minTime = 1024 * 1024 * 1024 + thread_pool = multiprocessing.Pool(processes=num_threads) + tasks = [] + idx = 0 + df_prof = [ + pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\') + for i in range(jobs) + ] + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + file_idx = idx % jobs + tasks += [ + thread_pool.apply_async(extract_kernel_time, args=(M, N, K, num_cus, EVEN_K, config, df_prof[file_idx])) + ] + idx += 1 + thread_pool.close() + thread_pool.join() + + for task in tasks: + config, myTime = task.get() + if myTime: + min_us = myTime / 1000 + if min_us < minTime: + minTime = min_us + bestConfig = config + else: + min_us = -1 + print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True) + post_end = datetime.now() + post_time = post_end - profile_end + if verbose: + print(f"post procesing time: {post_time}", flush=True) + return minTime, bestConfig, compile_time, profile_time, post_time + + +def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'): + d_type = name_to_tl_types[ty_name] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + def init_by_size_and_type(size, dtype, init_type): + if init_type == 'hpl': + return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5) + # This init type has element[i] in row[j] equal to sin(i+j*N) + elif init_type == 'trig_float': + M, N = size + return torch.reshape(torch.arange(0, M * N), (M, N)).sin().to(dtype=dtype, device='cuda') + elif init_type == 'zeros': + return torch.zeros(size, dtype=dtype, device='cuda') + elif init_type == "randn": + temp = torch.randn(size, dtype=dtype, device='cuda') + return temp + else: + raise ValueError("Bad matrix initialization type.") + + raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), torch.float32, init_type) + if needTrans: + raw_data = raw_data.T + if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ + (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8(): + input = raw_data.to(tl_to_torch_types[d_type]) + input_f16 = input.to(torch.float16) + else: + f8_tensor = raw_data.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + n_elements = raw_data.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + + return input, input_f16 + + +def matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, + mfmaInstrSize, kpack, EVEN_K): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + #assert a.is_contiguous(), "Matrix A must be contiguous" + #assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # 1D launch kernel where each block gets its own program. + + grid = num_cus + + streamk_gemm[ + grid, + ](a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), + BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, num_warps=num_warps, + num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, EVEN_K=EVEN_K) + return c + + +def test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose): + block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) + torch.manual_seed(0) + #a = torch.randn((M, K), device='cuda', dtype=datatype) + #b = torch.randn((K, N), device='cuda', dtype=datatype) + a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda') + b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda') + # Allocates output. + print(f"{block_k}") + EVEN_K = K % block_k == 0 + c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) + locks = torch.zeros((num_cus, ), device="cuda", dtype=torch.int32) + P = torch.zeros((num_cus, block_m * block_n), device="cuda", dtype=torch.float32) + triton_output = matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages, + waves_per_eu, mfmaInstrSize, kpack, EVEN_K) + torch_output = torch.matmul(a_fp16, b_fp16) + # print(f"triton_output={triton_output}") + # print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + atol = 1e-3 + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = '' + if verbose: + size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}' + if torch.allclose(triton_output.to(torch.float16), torch_output, atol=atol, rtol=rtol): + print(f'{size_str} Correct✅') + else: + print(f'{size_str} Incorrect❌') + + +def get_default_tuning_result_filename(): + git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") + git_branch_name = git_branch_name[0].decode() + git_commit_hash = run_bash_command("git rev-parse --short HEAD") + git_commit_hash = git_commit_hash[0].decode() + + dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + defaultName = f"tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" + return defaultName + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=0) + parser.add_argument("-n", type=int, default=0) + parser.add_argument("-k", type=int, default=0) + parser.add_argument("-col_a", action='store_true', default=False, help='whether matrix a is column major') + parser.add_argument("-col_b", action='store_true', default=False, help='whether matrix b is column major') + parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type") + parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type") + parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type") + parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step') + parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], + help='list of gpu ids to use for tuning') + parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') + parser.add_argument("--o", type=str, default=get_default_tuning_result_filename(), + help='yaml file to store tuning results') + parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') + parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") + parser.add_argument("--compare_wo_tuning", action='store_true', default=False, + help="Whether check result correctness") + parser.add_argument("--benchmark", action='store_true', default=False, help="Benchmark the given config") + parser.add_argument("--time_breakdown", action='store_true', default=False, + help="Show detailed time breakdown of each step during the tuning") + parser.add_argument("--verbose", action='store_true', default=False, + help="enables time_breakdown and additional logging messages") + parser.add_argument("--num_threads", type=int, default=16, + help="number of threads to use for kernel compilation and post processing") + parser.add_argument("--jobs", type=int, default=1, help="number of generated files") + parser.add_argument("--iters", type=int, default=1000, help="number of generated files") + parser.add_argument("--init_type", type=str, default='randn', + help="Initialization type for input matrices (default uniform rand [0, 1.0)])") + parser.add_argument("--no_warmup", action='store_true', default=False, help="Do not call the warmup kernel") + args = parser.parse_args() + + return args + + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8': tl.float8e4b8, + 'bf8': tl.float8e5b16, +} + + +def process_item(item): + M = item['M'] + N = item['N'] + K = item['K'] + col_a = False if item['rowMajorA'] == 'T' else True + col_b = False if item['rowMajorB'] == 'T' else True + del item['M'] + del item['N'] + del item['K'] + del item['rowMajorA'] + del item['rowMajorB'] + return M, N, K, col_a, col_b, item + + +def type_name_to_bytes(ty_name): + if '32' in ty_name: + return 4 + if '16' in ty_name: + return 2 + if '8' in ty_name: + return 1 + else: + print(f"Unrecognized input type name {ty_name}") + sys.exit(1) + + +def format_output(unformatted): + if unformatted < 0.0001: + formatted = "{:.3e}".format(unformatted) + elif unformatted > 1000: + formatted = "{:.1f}".format(unformatted) + else: + formatted = "{:.2f}".format(unformatted) + return formatted + + +def main(): + args = parse_args() + matrix_size_file = args.gemm_size_file + tuning_output_file = args.o + keepTmp = args.keep + run_bench = args.benchmark + jobs = args.jobs + iters = args.iters + skipWarmup = args.no_warmup + num_cus = 304 + + # Get GPU ids + ngpus = args.ngpus + gpu_ids = args.gpu_ids + if ngpus != 0 and gpu_ids: + print("--ngpus and --gpu_ids are mutually exclusive options") + return os.EX_USAGE + if ngpus == 0 and not gpu_ids: + ngpus = 1 + if ngpus != 0: + gpus = range(ngpus) + if gpu_ids: + gpus = gpu_ids + + if run_bench: + gpus = [gpus[0]] + jobs = 1 + + # Get element type + dtype_a = args.dtype_a + dtype_b = args.dtype_b + dtype_c = args.dtype_c + dtype_p = 'fp32' + dtype_lock = 'int32' + if dtype_a not in name_to_tl_types or dtype_b not in name_to_tl_types or dtype_c not in name_to_tl_types: + print(f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}") + print("Supported types: ", list(name_to_tl_types.keys())) + sys.exit(1) + + mnks = [] + # TODO: make it more robust to get user input + init_type = args.init_type + if matrix_size_file == "" or not os.path.isfile(matrix_size_file): + M = args.m + N = args.n + K = args.k + col_a = args.col_a + col_b = args.col_b + mnks = [(M, N, K, col_a, col_b, None)] + else: + with open(matrix_size_file) as file: + matrix_sizes = yaml.safe_load(file) + for item in matrix_sizes: + M, N, K, col_a, col_b, item = process_item(item) + mnks.append((M, N, K, col_a, col_b, item)) + + # Check correctness from given configs + if args.compare_wo_tuning: + for (M, N, K, col_a, col_b, myConfig) in mnks: + test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, True) + return + + configs_full = get_full_tuning_space() + + start_time = datetime.now() + if run_bench: + print(f"Benchmarking gemm with {dtype_a} inputs") + print("trans M N K TFLOPS us") + else: + print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True) + f_results = open(tuning_output_file, 'w') + + for (M, N, K, col_a, col_b, myConfig) in mnks: + start_local_time = datetime.now() + # Obtain a pruned tuning space according to gemm size + # If running benchmark, use the provided config + pruned_configs = [myConfig] if run_bench else prune_configs(M, N, K, configs_full, type_name_to_bytes(dtype_a), + type_name_to_bytes(dtype_b)) + + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}' + if not run_bench: + print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True) + else: + print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", end="") + + # The main tuning funtion for one gemm size + verbose_level = 0 + if args.time_breakdown: + verbose_level = 1 + if args.verbose: + verbose_level = 2 + minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config( + M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, pruned_configs, + run_bench, jobs, iters, skipWarmup, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level) + + EVEN_K = True if K % bestConfig.get('BLOCK_SIZE_K') == 0 else False + # post processing the numbers + perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6) + tri_tflops = perf_tflops(minTime) + formatted_tflops = format_output(tri_tflops) + minTime = format_output(minTime) + if not run_bench: + print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True) + + bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, bestConfig, None, + None, None, None, None) + if not run_bench: + print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True) + + # write best config to tuning_results.yaml + if run_bench: + print(f"{formatted_tflops} {minTime}") + + sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str} + sizeDict.update(bestConfig) + if not run_bench: + f_results.write("- " + str(sizeDict) + " ") + f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') + + # remove generated files if asked to + if not keepTmp: + for i in range(jobs): + generated_script = generated_kernel_name(M, N, K, i) + os.remove(generated_script) + if not skipWarmup: + os.remove(generated_script + ".failed_configs") + for f in glob.glob(f"results_{i}.*"): + os.remove(f) + + # Check correctness if asked to + if args.compare: + print("correctness: ", end=" ", flush=True) + test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False) + elif not run_bench: + print("", flush=True) + + end_local_time = datetime.now() + if not run_bench: + print( + f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", + flush=True) + + if not run_bench: + f_results.close() + + end_time = datetime.now() + tuning_time = end_time - start_time + if not run_bench: + print(f"Tuning ends at: {end_time}") + print(f"Total tuning time (h:m:s): {tuning_time}") + + +if __name__ == '__main__': + sys.exit(main()) From 1d2e06681f9cb912086cade50cc53338d2a490b8 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Tue, 6 Aug 2024 14:20:12 -0300 Subject: [PATCH 09/20] Add explicit multiply-reduce GEMM kernel (#621) * Add explicit multiply-reduce GEMM kernel * Remove `SPLIT_K` argument from kernel * Remove `GROUP_SIZE_M` argument from kernel * Remove conditional call to `tl.dot` from kernel * Remove table with performance data from README --- python/perf-kernels/README.md | 8 ++++ .../perf-kernels/multreduce_matmul_kernel.py | 45 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 python/perf-kernels/multreduce_matmul_kernel.py diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md index 5bcedbf49cdd..b8f930ef94ea 100644 --- a/python/perf-kernels/README.md +++ b/python/perf-kernels/README.md @@ -61,3 +61,11 @@ fp32, bf16 and f8 (both e5m2 and e4m3) datatypes. ## `03-matrix-multiplication-stream-k.py` This script contains the GEMM kernel that implements [stream-k](https://arxiv.org/abs/2301.03598) + +## `multreduce_matmul_kernel.py` + +Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes. Such +small block sizes aren't natively supported by `tl.dot` operator. + +Despite being numerically correct, this kernel performed worse than a corresponding GEMM kernel that +used `tl.dot` with minimum block size equal to $16$. diff --git a/python/perf-kernels/multreduce_matmul_kernel.py b/python/perf-kernels/multreduce_matmul_kernel.py new file mode 100644 index 000000000000..61535d5bcdd3 --- /dev/null +++ b/python/perf-kernels/multreduce_matmul_kernel.py @@ -0,0 +1,45 @@ +import triton +import triton.language as tl + + +# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes. +# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch). +@triton.jit +def multreduce_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, BIAS: tl.constexpr, EVEN_K: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + if BIAS: + bias_ptrs = bias_ptr + offs_am * stride_bias + bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0) + acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # Dot product implemented as explicit multiply-reduce: + a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype) + b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype) + accumulator += tl.sum(a * b, axis=1) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(c_ptr.type.element_ty) + if BIAS: + c += bias[:, None] + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) From 11e4447d42e3d5ed9f8b1214b52fd57f5ae74d1d Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Tue, 13 Aug 2024 15:28:31 -0300 Subject: [PATCH 10/20] Copy *tune_gemm* from `triton-mlir` branch to `main_perf` branch (#614) * Copy *tune_gemm* from `triton-mlir` branch to `main_perf` branch The source commit in `triton-mlir` branch is the following one: ``` commit cf44637139ba441342b909977f51f7cfec2c9963 Author: Lixun Zhang Date: Tue Jul 23 14:22:01 2024 -0500 [tuning] gemm tuning script v3.3 (#606) ``` *tune_gemm* was copied from the source branch directory `scripts/amd/gemm` to the destination branch directory `python/perf-kernels/tune_gemm`. The SHA-256 hashes of *tune_gemm* files are the following ones: ``` 423aef1deb6c60f6578a1ecfc94d2473f8746b00d0368c553d31641fcfa5e354 README.md 46ab93978fee33f75df23332f12546dae7910478c391f08b7b1ebd415d8266b7 icache_flush.py f18711544641b810a652e6a6629bfa2b613f6ade87399e88fdf05b81d4af58a4 matmul.py 84a1c80ede36d3154e51188276eda2d2d0f52ed4f496ff69349c390d83b8ec10 matmul_kernel.py 2812b40183637bc8d7e47d283c7d66b1792134a43de76f3eacf7b9b3e1c2431a one_config.py 0ac09c33b0173cea06ddabbf9f4e3afa1816781dea4fdcce5894a7e7d6a80e19 rocprof_gemm.py 00eff41cf1c0bfc41d623e42b51706af67639fec76146741e2067d2a93e0148a utils/file_generator.py cb7afb773ccee835b00396cccf87e0d44fe513131161f031fae42453725b3c82 utils/utils.py 59f23811b660e49e566927853926a21f02a7014bb19c8ea67e6b382db6c59900 tune_gemm.py e787f35d750b869f113b3c01692f64243a9cb8a71a18ade2f0465f614f7284e4 tune_gemm.sh ``` The files were kept as-is despite `pre-commit` intentions to change them. After that, *tune_gemm* directory in code and documentation was fixed to reflect it's new location. --- python/perf-kernels/tune_gemm/README.md | 316 ++++++ python/perf-kernels/tune_gemm/icache_flush.py | 94 ++ python/perf-kernels/tune_gemm/matmul.py | 375 +++++++ .../perf-kernels/tune_gemm/matmul_kernel.py | 64 ++ python/perf-kernels/tune_gemm/one_config.py | 90 ++ python/perf-kernels/tune_gemm/rocprof_gemm.py | 318 ++++++ python/perf-kernels/tune_gemm/tune_gemm.py | 937 ++++++++++++++++++ python/perf-kernels/tune_gemm/tune_gemm.sh | 27 + .../tune_gemm/utils/file_generator.py | 355 +++++++ python/perf-kernels/tune_gemm/utils/utils.py | 115 +++ 10 files changed, 2691 insertions(+) create mode 100644 python/perf-kernels/tune_gemm/README.md create mode 100644 python/perf-kernels/tune_gemm/icache_flush.py create mode 100644 python/perf-kernels/tune_gemm/matmul.py create mode 100644 python/perf-kernels/tune_gemm/matmul_kernel.py create mode 100644 python/perf-kernels/tune_gemm/one_config.py create mode 100755 python/perf-kernels/tune_gemm/rocprof_gemm.py create mode 100755 python/perf-kernels/tune_gemm/tune_gemm.py create mode 100755 python/perf-kernels/tune_gemm/tune_gemm.sh create mode 100644 python/perf-kernels/tune_gemm/utils/file_generator.py create mode 100644 python/perf-kernels/tune_gemm/utils/utils.py diff --git a/python/perf-kernels/tune_gemm/README.md b/python/perf-kernels/tune_gemm/README.md new file mode 100644 index 000000000000..5a986a9f987d --- /dev/null +++ b/python/perf-kernels/tune_gemm/README.md @@ -0,0 +1,316 @@ +# GEMM tuning script (current v3.3) + +## matmul kernel + +The matmul kernel implementation can be found as [matmul_kernel.py](https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tune_gemm/matmul_kernel.py), which includes the following features: +- grouping order of workgroup id, which is controlled by `GROUP_SIZE_M`, that +implements L2 cache optimization introduced in the [tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations). +- split-k algorithm, which is controlled by `SPLIT_K`. +- Bias along M dim, which is controlled by `BIAS` and `bias_ptr`. +- Masked load along K dim inside the loop, which is controlled by `EVEN_K`. +This means `BLOCK_SIZE_K` does not need to divide K dim. + +### Differences between the tutorial + +Unlike the [matmul tutorial](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py) (referred as the tutorial), +the matmul kernel used in the tuning script (referred as the kernel) does not +guard load along M and N dim +([this](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py#L282-L283) shows how this is done in the tutorial). +When `BLOCK_SIZE_M` or `BLOCK_SIZE_N` does not divide M or N, the kernel will +load out-of-bound data. +In most cases this is fine, since the kernel does masked store at the end. +However, this may lead to GPU memory access fault in some cases, especially +when the tensor is large. +We will fix this issue in the future. + + +## Tuning script usage + +### Tuning mode + +The tuning script can take one or more gemm sizes and run tuning for them. +The input gemm sizes are prepared in a yaml file. Here is an example yaml file: +```yaml +- {'M': 4864, 'N': 4096, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N'} +- {'M': 512, 'N': 512, 'K': 512, 'rowMajorA': 'T', 'rowMajorB': 'N'} +``` + +The tuning script works as follows +```python +./tune_gemm.py --gemm_size_file input.yaml [options] +``` +The following `options` are supported in the tuning mode + +- Input data types: + - `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: input and output element type. + - Supported `dtype`: fp16 (default), bf16, fp8, bf8, int8, int32, fp32 +- Parallel compilation of kernels: + - `num_threads n` controls that n threads will + be used in the compilation stage. The default value is 32. + - `--no_warmup` can be used to skip the compilation stage. Thus kernels will be + compiled during the profiling stage. This increases tuning time. But it's + required for some old torch version, in which some function used in the warmup + kernel launch is not supported. +- Parallel profiling of kernels: The tuning space is first divided into a number +of tasks, which is controlled by `--jobs n`. And all the tasks can be profiled in +parallel on a number of GPUs in the system. There are two ways to specify which +GPU(s) we want to use for profiling. Note that these flags cannot be use together. +By default, only one task is generated and profiled on GPU0. + - `--ngpus n`: GPU 0,1,.., n-1 will be used. + - `--gpu_ids ids`: `ids` are comma separated gpu ids and GPUs in `ids` will be used. +- General tuning control flags + - `--init_type INIT_TYPE` defines how input data are initialized. `INIT_TYPE` can be + - hpl: uniform distribution between -.5 and .5 + - trig_float: the distribution of elements in the flattened tensor follow + the `sin` function. + - zeros: initialize all data as 0, i.e. `torch.zeros` + - randn (default): normal distribution, i.e. `torch.randn` + - `--rotating_tensor SIZE`: provide the size of memory used for rotating tensor. + The default is 0, meaning rotating tensor is not used. + - `--icahe_flush`: If true, the script will generate a kernel to flush i-cache. + The default is False. + - `--bias_vector`: If true, a bias vector along the M dim is applied. + The default is False. +- Correctness check + - `--compare` will check the correctness of the best config for each gemm size. + - `--compare_wo_tuning` will check the correctness of the config provided in + the yaml file. If this is set, user needs to provide all the parameters in + the input yaml file. Example can be found in the benchmark mode section. +- Logistics + - `--keep` can be used to keep the files generated during the tuning process. + Be default, intermediate files are removed at the end. + - `--time_breakdown`: If set, the script will print out elapsed time during + each stage of the tuning in real-time. The default is False. + - `--verbose` will enable more logging message than `--time_breakdown`, such + as output from rocprofv2 + - `--o OUTPUT` can be used to control the output filename to store the tuning + result. The default filename is `tuning_results_branchName@gitCommit_timeStamp.yaml`. + Therefore, each time the user runs the tuning script, a different output file + will be generated. +- Hacks + - `--hack_triton_compiler`: If set, the triton source code will be modified + to provide a static backend target so that the compiler will not query + GPU information. This makes sure that during the compilation stage, no + hip runtime kernels are launched. + Note that this is a very hacky option, because + - It modifies the triton compiler directly, which is located from + `pip show triton`. + - It does string match and replace to modify the code. + - It does not restore the code when the tuning session terminates. + +Here are some example usages of running the script for tuning: + +Tune some gemm sizes with f16 input +```python +./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 --o output.yaml +``` +It's recommended to use as many GPUs as possible and set `--jobs` to +a value that is 4 to 6 times the number of GPUs. + +If you are only allowed to use a subset of the GPUs, you can +```python +./tune_gemm.py --gemm_size_file input.yaml --gpu_ids 0,1,3,4 --jobs 32 --o output.yaml +``` +This runs the profiling on GPU 0,1,3,4. + +For bf8 input +```python +./tune_gemm.py --gemm_size_file input.yaml --ngpus 8 --jobs 32 -dtype_a bf8 -dtype_b bf8 +``` + +Check correctness of the tuned configs +```python +./tune_gemm.py --gemm_size_file output.yaml --compare_wo_tuning +``` + + +### Benchmark mode + +In benchmark mode, the script will run a single given config multiple times to +collect performance data. The benchmark mode works as +The tuning script works as follows +```python +./tune_gemm.py --gemm_size_file input.yaml [options] --benchmark +``` +The supported `options` are as followings +- `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: same as tuning mode. +- `--iters n` controls the number of iterations to run the kernel. +The default value is 1000. +- `--icahe`: same as tuning mode +- `--rotating_tensor SIZE`: same as tuning mode + + +## Tuning script implementation overview + +The general idea of the tuning script can be summarized as +- Compile all the kernels in the tuning space in parallel. +- Divide the tuning space into tasks and invoke `rocprofv2` once per +task. This will save invocation overhead of the profiler. +- Profile tasks in parallel on multiple GPUs. + +For detailed implementation, please refer to the changelog of each version. + + +# Changelog + +## GEMM tuning script v1 + +Shucai (@scxiao) implemented the first version of gemm tuning script: https://github.com/ROCmSoftwarePlatform/triton/pull/309 + +## GEMM tuning script v2 + +This version is based on v1 and @alefimov-amd's thread pool https://github.com/ROCmSoftwarePlatform/triton/pull/310 + +### Main features +- `rocprof` is used to measure the time for kernels in the full tuning space +- Each kernel is executed 10 times and the execution time of the last instance is used +- All kernels are compiled in parallel +- Two modes for correctness checking + - During tuning, check correctness with the best perf_config for the current gemm size + - Without tuning, check correctness based on the tuning results, which includes best perf_config for each gemm size +- The process takes about 30 - 40 minutes for the full tuning space with ~15000 configs +- Limitations + - For now, only support fp16 as inputs. It should be trivial to extend to other types, but may require some work for mixed inputs + +### Overview of implementations + +Workflow of the tuning process +1. Generate the full tuning space. For now the `range`s for each tuning parameter are hard-coded +2. Prune the tuning space according to the current GEMM size and some rules + - BLOCK_SIZE must be equal or larger than the mfma instruction size. + - SPLIT_K * BLOCK_SIZE_K must divide K. Therefore, we do not need EVEN_K in the kernel. + - When split-k is not needed, i.e. both M and N are large, it must be 1 + - GROUP_M * BLOCK_SIZE_M must be smaller than M. Otherwise, GROUP_M must be 1 + - When BLOCK_SIZE_K = 128, neither BLOCK_SIZE_M or BLOCK_SIZE_N can be 128. Otherwise too much LDS will be required. **Needs further investigation** + - Skip BLOCK_SIZE_M or BLOCK_SIZE_N if they are over 2 times larger than M or N. +3. Open a file `generated_kernel{M}-{N}-{K}-{gpuid}.py` and write the following into the file + 1. For each config in the pruned space, generate a kernel with name `matmul_kernel_{configStr}`, where `configStr` contains the gemm size and the tuning parameters. + 2. Generate `matmul` function for each config in a similar way + 3. Generate `try_config` functions for each `matmul` function. + 4. Generate `test_gemm`, which does + 1. Add all `try_config` functions in the thread_pool by `thread_pool.apply_async(try_config)`. This is used to compile all kernels in parallel. + 2. Call each `matmul` function in a for loop of 10 iterations + 5. Generate `main` function +4. Run the generated script with 16 workers. This will compile all kernels in parallel. +5. Invoke `rocprof` on the generated script +6. Post process `results.csv` by extract the execution time of the last instance of each kernel. Pick the best one, write to file, and return. + +## GEMM Tuning Script v3 + +### API changes + +- Input and output data types can be provided as `-dtype_a`, `-dtype_b`, and `-dtype_c`. +The provided types must be one of ['fp32', 'fp16', 'bf16', 'fp8', 'bf8', 'int8']. +- Row/col major-ness of operand a and b can be provided as `-col_a` and `-col_b`. +If set, it means the corresponding operand is column major. +The major-ness is considered as problem input. +So they should be included in the input yaml file. However, in the yaml file, user should +set `rowMajowA` and `rowMajorB` as shown in the example below. +- `--benchmark` is used to control if the perf config in the input yaml file is used as the tuning space. +- `--jobs` is used to control the number of .py files for generated kernels. +Note that this can be different from `ngpus`. This usually means multiple kernel files +will be profiled on each GPU. +This is necessary to keep each file "small" in terms of execution time. + +### Implementation changes +- `gen_input` is used to generate matmul inputs. +- Time measurement + - In benchmark mode, the kernel is executed 1000 times. + - In tuning mode, each kernel is executed 200 times. We cannot afford to larger runs since rocprof hangs if the session takes too long. + - In both tuning and benchmark mode, kernel time is measured as the average execution time of the last 100 instances. +- Added error recovery. This helps when rocprof crashes in multi-processing mode. + + + +## GEMM Tuning Script v3.1 + +### API changes + +- Added `matrix_instr_nonkdim` into the tuning space. Now we can tune mfma instruction size. + + +## GEMM Tuning Script v3.2 + +### API changes + +- Added `--rotating_tensor ` to use rotating memory blocks in each iteration, size in MB. Default is 0MB. +- Added `--icache_flush` to flush icache in each iteration. +Note, icache flush needs the module `python-hip`, which can be installed as: +`python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version` +Rotating tensor and icache flush are to make perf numbers are closer to that in real applications. +- Added `--bias_vector` to support kernel execution with bias (bias vector is of the same size as the number of rows of the output matrix, +so each element of the bias vector is added to all elements of the corresponding row of the output matrix.) + + +## GEMM Tuning Script v3.3 + +### API changes + +no API changes + +### Implementation changes + +- We use a dedicated file (named `get_filename_myKernels()`) to keep all the kernels +in the tuning space. +- Inside the for loop of tuning, each iteration tunes one gemm size + 1. Update kernel stage: Different gemm size may need different configs. We keep track + of the current tuning space. And if the current gemm size needs some configs that is + not included in the current tuning space, we expand the tuning space with the newly + added configs. + - This means if two gemm sizes share some configs, these configs will be compiled + once. This will greatly reduce batch tuning time. + 2. Compilation stage: + - We generate a single compilation driver file, named compile_driver.py (this is + obtained from `get_filename_compile_driver`) which contains the wrapper functions + of all the configs in the **pruned** tuning space for this gemm size. + - All the kernels will be compiled by 32 threads by default. Compiling all the + kernels in a single file in parallel is faster than splitting them into multiple + files. This can greatly reduce the compile time of the tuning process. + - Note that we no longer generate matmul_kernel in this file. Kernels are imported + from myKernels.py. + 3. Profile stage + - We generate one task file per job, named `profile_driver_MxNxK_{job_id}.py` + (this is obtained from `get_filename_profile_driver`). The only difference is + that we no longer generate matmul_kernel in this file. Kernels are imported + from myKernels.py. +- `configStr` does not contain gemm size anymore. This allows the same matmul_{configStr} +kernel to be reused by different gemm sizes. +- `configStr` does not contain `_bias` if bias is provided. This is because we do not +expect to compare the same kernel w/ and w/o bias. Therefore, we treat bias in the same +way as gemm sizes. +- Add support for `EVEN_K` in the matmul kernel. Now the kernel support `BLOCK_SIZE_K` +that cannot divide `K`. +- Tuning result file is open and closed inside the tuning loop, enabling timely flush +of the tuning results. +- Now we use `rocprofv2` to measure kernel time. +- We can use `--hack_triton_compile` to avoid all GPU activities during the compilation +stage. This is achieved by modifying the triton frontend compiler in the following +places: + - Return True from the `is_active()` function in the hip hackend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L433) + - Return statically constructed GPUTarget from the `get_current_target()` + function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L437) + - Return False from the `is_active()` function in the cuda hackend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/nvidia/backend/driver.py#L383) + - Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589) + + +# One config running script + +`one_config.py` is a script that runs one given matmul config. +It is an interface to `tune_gemm.py` functionality and could be used for triton debugging. + +## Usage + +This script supports two methods to specify configuration parameters. + +Variant 1: Separate command line attributes. + +```bash +python one_config.py -m 256 -n 256 -k 256 --block_m 64 --block_n 64 --block_k 64 --group_m 1 --split_k 2 --num_warps 2 --num_stages 0 --waves_per_eu 0 --matrix_instr_nonkdim 16 --kpack 2 +``` + +Variant 2: one-line config description. +This is how configs are printed by `tune_gemm.py` script + +```bash +python one_config.py --config_str M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP2_mfma16 +``` diff --git a/python/perf-kernels/tune_gemm/icache_flush.py b/python/perf-kernels/tune_gemm/icache_flush.py new file mode 100644 index 000000000000..320e746d30d4 --- /dev/null +++ b/python/perf-kernels/tune_gemm/icache_flush.py @@ -0,0 +1,94 @@ +import ctypes +import array +import random +import math + +# the hip module can be installed as +# `python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version` +# more information about hip-python is at: https://github.com/ROCm/hip-python +from hip import hip, hiprtc + +def hip_check(call_result): + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + elif ( + isinstance(err, hiprtc.hiprtcResult) + and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS + ): + raise RuntimeError(str(err)) + + return result + +# S_ICACHE_INV Invalidate entire first level instruction cache. +# There must be 16 separate S_NOP instructions or a jump/branch instruction +# after this instruction to ensure the internal instruction buffers are also +# invalidated. +def gen_kernel(): + source = b"""\ + extern "C" __global__ void icache_flush_kernel() { + asm __volatile__("s_icache_inv"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + asm __volatile__("s_nop 0"); + } + """ + + # print(f"source = {source}") + prog = hip_check(hiprtc.hiprtcCreateProgram(source, b"icache_flush_kernel", 0, [], [])) + progs = hip.hipDeviceProp_t() + hip_check(hip.hipGetDeviceProperties(progs, 0)) + arch = progs.gcnArchName + cflags = [b"--offload-arch="+arch] + err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags) + if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS: + log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog)) + log = bytearray(log_size) + hip_check(hiprtc.hiprtcGetProgramLog(prog, log)) + print(f"log = {log.decode()}, err = {err}") + raise RuntimeError(log.decode()) + + code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog)) + code = bytearray(code_size) + hip_check(hiprtc.hiprtcGetCode(prog, code)) + module = hip_check(hip.hipModuleLoadData(code)) + kernel = hip_check(hip.hipModuleGetFunction(module, b"icache_flush_kernel")) + + return kernel + +kernel = gen_kernel() +progs = hip.hipDeviceProp_t() +hip_check(hip.hipGetDeviceProperties(progs, 0)) +cu_num = progs.multiProcessorCount + +def icache_flush(): + block = hip.dim3(x=64) + grid = hip.dim3(cu_num * 60) + + hip_check(hip.hipModuleLaunchKernel( + kernel, + *grid, + *block, + sharedMemBytes=0, + stream=None, + kernelParams=None, + extra=() + ) + ) diff --git a/python/perf-kernels/tune_gemm/matmul.py b/python/perf-kernels/tune_gemm/matmul.py new file mode 100644 index 000000000000..5b39d9330bff --- /dev/null +++ b/python/perf-kernels/tune_gemm/matmul.py @@ -0,0 +1,375 @@ +""" +Matrix Multiplication Tuning Scripts, Changed from the tutorial example "python/tutorials/03-matrix-multiplication.py" +""" + +import torch + +import triton +import triton.language as tl +import argparse +import sys +import yaml +import os +import subprocess + + + +# global flag to indicate whether using the full tuing space +tuning_full_space = True + +# pruned some unreasonable config +def prune_configs(configs, named_args): + # call only for full tuning space + if not tuning_full_space: + return configs + + SIZE_M = named_args["a_ptr"].shape[0] + SIZE_N = named_args["b_ptr"].shape[1] + SIZE_K = named_args["a_ptr"].shape[1] + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K =\ + kw["BLOCK_SIZE_M"], kw["BLOCK_SIZE_N"], kw["BLOCK_SIZE_K"] + SPLIT_K = kw["SPLIT_K"] + if SIZE_M <=32 and BLOCK_SIZE_M != 32: + continue + if SIZE_N <=32 and BLOCK_SIZE_N != 32: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K): + continue + pruned_configs.append(config) + + return pruned_configs + + +def get_full_tuning_space(use_split_k): + configs = [] + if not tuning_full_space: + return configs + + block_mn_range = [32, 64, 128] + block_k_range = [32, 64] + split_k_range = [1, 2, 4, 5, 8, 10] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for split_k in split_k_range: + for num_stages in num_stage_range: + configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k}, num_stages=num_stages, num_warps=num_warps)) + + return configs + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs= get_full_tuning_space(True) if tuning_full_space else [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 8}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 10}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 8}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 10}, num_stages=1, num_warps=1), + ], + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': prune_configs, + 'perf_model': None, + "top_k": None + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_SIZE_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def matmul_kernel_splitK( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + if SPLIT_K == 1: + offs_k = tl.arange(0, BLOCK_SIZE_K) + else: + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + if torch.version.hip is None: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # 1D launch kernel where each block gets its own program. + + grid_splitK = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel_splitK[grid_splitK]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + ACTIVATION=activation + ) + + return c + + +def test_correctness(M, N, K, datatype = torch.float16): + torch.manual_seed(0) + a = torch.randn((M, K), device='cuda', dtype=datatype) + b = torch.randn((K, N), device='cuda', dtype=datatype) + triton_output = matmul(a, b) + torch_output = torch.matmul(a, b) + print(f"triton_output={triton_output}") + print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + size_str = f'size, (M: {M}, N: {N}, K: {K})' + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print(f'✅ Triton and Torch match for {size_str}') + else: + print(f'❌ Triton and Torch differ for {size_str}') + + +def run_speed(M, N, K, datatype, use_rocprof, provider): + a = torch.randn((M, K), device='cuda', dtype=datatype) + b = torch.randn((K, N), device='cuda', dtype=datatype) + quantiles = [0.5, 0.2, 0.8] + if provider == 'pytorch': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + return min_ms + +def run_bash_command(commandstring): + #print( commandstring ) + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE) + return proc.stdout.splitlines() + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=0) + parser.add_argument("-n", type=int, default=0) + parser.add_argument("-k", type=int, default=0) + parser.add_argument("-dtype", type=str, default='fp16', help="Input data type, default is fp16") + parser.add_argument("--specify_type", action='store_true', default=False, help="Whether user specify data type, default false") + parser.add_argument("--specify_size", action='store_true', default=False, help="Whether user specify input matrix size, default false") + parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") + parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') + parser.add_argument("--rocprof", action='store_true', default=False, help='Use rocprof to measure kernel time, default uses do_bench()!') + parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") + args = parser.parse_args() + + return args + +def main(): + args = parse_args() + dtype = torch.float16 + if args.specify_type: + if args.dtype == 'fp16': + dtype = torch.float16 + elif args.dtype == 'fp32': + dtype = torch.float32 + elif args.dtype == 'bf16': + dtype = torch.bfloat16 + else: + print(f"Unsupported datatype {args.dtype}") + sys.exit(1) + use_rocprof = args.rocprof + verbose = args.v + + mnks = [] + if args.specify_size: + M = args.m + N = args.n + K = args.k + if M == 0 or N == 0 or K == 0: + print(f"Input matrix size: (M {M}, N {N}, K {K}) contains dim size 0!") + mnks = [(M, N, K)] + else: + matrix_size_file = args.gemm_size_file + if matrix_size_file == "" or not os.path.isfile(matrix_size_file): + print(f"Matrix size file: {matrix_size_file} does not exist!") + sys.exit(1) + + with open(matrix_size_file) as file: + matrix_sizes = yaml.safe_load(file) + + for sizes in matrix_sizes: + M = sizes['M'] + N = sizes['N'] + K = sizes['K'] + mnks.append((M, N, K)) + + + for (m, n, k) in mnks: + min_ms = run_speed(m, n, k, dtype, use_rocprof, 'triton') + + # function to compute flops + perf_flops = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) + + if args.compare: + test_correctness(m, n, k, dtype) + best_config = matmul_kernel_splitK.get_best_config() + + if use_rocprof: + dtype_str = 'fp16' if (not args.specify_type) else args.dtype + block_m = best_config.kwargs['BLOCK_SIZE_M'] + block_n = best_config.kwargs['BLOCK_SIZE_N'] + block_k = best_config.kwargs['BLOCK_SIZE_K'] + group_m = best_config.kwargs['GROUP_SIZE_M'] + split_k = best_config.kwargs['SPLIT_K'] + # num_warps = best_config['num_warps'] + num_warps = best_config.num_warps + driver = 'rocprof_gemm.py' + TRITON_DIR = os.getenv('TRITON_DIR') + if TRITON_DIR is not None: + driver = os.path.join(TRITON_DIR, 'python/perf-kernels/tune_gemm', driver) + run_cmd = f'python {driver} -m {m} -n {n} -k {k} \ + -block_m {block_m} -block_n {block_n} -block_k {block_k} \ + -group_m {group_m} -split_k {split_k} -num_warps {num_warps} \ + -dtype {dtype_str}' + prof_cmd = f'rocprof --stats {run_cmd}' + run_bash_command(prof_cmd) + + parse_result_cmd = f'sed -n \'/matmul_kernel/p\' results.stats.csv | awk -F \',\' \'{{print $4}}\'' + parse_outputs = run_bash_command(parse_result_cmd) + min_ms = int(parse_outputs[0]) / 1000000 + + out_str = f'SIZE: {m},{n},{k} ' + # print best config + if verbose: + out_str += f' best_config: ({best_config}), ' + out_str += f'TFLOPS: {perf_flops(min_ms)} time(ns): {min_ms * 1000000}' + print(out_str) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/tune_gemm/matmul_kernel.py b/python/perf-kernels/tune_gemm/matmul_kernel.py new file mode 100644 index 000000000000..d5f854f3d8a1 --- /dev/null +++ b/python/perf-kernels/tune_gemm/matmul_kernel.py @@ -0,0 +1,64 @@ +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, bias_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_bias, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr, + EVEN_K: tl.constexpr +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if SPLIT_K == 1: + offs_k = tl.arange(0, BLOCK_SIZE_K) + else: + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + if BIAS: + bias_ptrs = bias_ptr + offs_am * stride_bias + bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0) + acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + c = accumulator.to(c_ptr.type.element_ty) + if BIAS: + c += bias[:, None] + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) diff --git a/python/perf-kernels/tune_gemm/one_config.py b/python/perf-kernels/tune_gemm/one_config.py new file mode 100644 index 000000000000..52e6fba4a0a5 --- /dev/null +++ b/python/perf-kernels/tune_gemm/one_config.py @@ -0,0 +1,90 @@ +""" +Script for running one Matrix Multiplication kernel config at a time +""" + +import argparse +import re +import sys +import tune_gemm + +def parse_args(): + parser = argparse.ArgumentParser( + prog="check corectness of particular config for tuning gemm script", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=0) + parser.add_argument("-n", type=int, default=0) + parser.add_argument("-k", type=int, default=0) + parser.add_argument("-col_a", action='store_true', default=False, help='whether matrix a is column major') + parser.add_argument("-col_b", action='store_true', default=False, help='whether matrix b is column major') + parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type") + parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type") + parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type") + parser.add_argument("--init_type", type=str, default='randn', help="Initialization type for input matrices (default uniform rand [0, 1.0)])") + parser.add_argument("--bias_vector", action='store_true', default=False, help="apply bias vector") + parser.add_argument("--block_m", type=int, default=0) + parser.add_argument("--block_n", type=int, default=0) + parser.add_argument("--block_k", type=int, default=0) + parser.add_argument("--group_m", type=int, default=0) + parser.add_argument("--split_k", type=int, default=0) + parser.add_argument("--num_warps", type=int, default=0) + parser.add_argument("--num_stages", type=int, default=0) + parser.add_argument("--waves_per_eu", type=int, default=0) + parser.add_argument("--matrix_instr_nonkdim", type=int, default=0) + parser.add_argument("--kpack", type=int, default=0) + parser.add_argument("--config_str", type=str, default="", help="can take from tune_gemm.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP2_mfma16") + args = parser.parse_args() + + return args + + +def parse_config(cfg_str): + values = cfg_str.split("_") + config_name = {"M": "M", + "N": "N", + "K": "K", + "BM": "BLOCK_SIZE_M", + "BN": "BLOCK_SIZE_N", + "BK": "BLOCK_SIZE_K", + "GM": "GROUP_SIZE_M", + "SK": "SPLIT_K", + "nW": "num_warps", + "nS": "num_stages", + "EU": "waves_per_eu", + "kP": "kpack", + "mfma": "matrix_instr_nonkdim" + } + config = {} + for val in values: + match = re.search("([a-zA-Z]*)([0-9]*)", val) + if match: + cfg_field_name = config_name[match.group(1)] + config[cfg_field_name] = int(match.group(2)) + return config + + +def main(): + args = parse_args() + if args.config_str: + config = parse_config(args.config_str) + else: + config = {"M": args.m, + "N": args.n, + "K": args.k, + "BLOCK_SIZE_M": args.block_m, + "BLOCK_SIZE_N": args.block_n, + "BLOCK_SIZE_K": args.block_k, + "GROUP_SIZE_M": args.group_m, + "SPLIT_K": args.split_k, + "num_warps": args.num_warps, + "num_stages": args.num_stages, + "waves_per_eu": args.waves_per_eu, + "kpack": args.kpack, + "matrix_instr_nonkdim": args.matrix_instr_nonkdim + } + tune_gemm.test_correctness(config["M"], config["N"], config["K"], args.col_a, args.col_b, args.dtype_a, args.dtype_b, args.dtype_c, args.init_type, config, args.bias_vector, verbose=True) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/python/perf-kernels/tune_gemm/rocprof_gemm.py b/python/perf-kernels/tune_gemm/rocprof_gemm.py new file mode 100755 index 000000000000..8103fad554f7 --- /dev/null +++ b/python/perf-kernels/tune_gemm/rocprof_gemm.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 +import argparse +import sys + +import torch +import triton +import triton.language as tl + + +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_SIZE_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def matmul_kernel_splitK( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + if torch.version.hip is None: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +# Kernel no split K +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, EVEN_K: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + offs_k = tl.arange(0, BLOCK_SIZE_K) + if torch.version.hip is None: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def matmul(a, b, block_m, block_n, block_k, group_m, split_k, num_warps, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # 1D launch kernel where each block gets its own program. + + if need_split_k(M, N, K): + grid_splitK = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'] + ) + matmul_kernel_splitK[grid_splitK]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M = block_m, + BLOCK_SIZE_N = block_n, + BLOCK_SIZE_K = block_k, + GROUP_SIZE_M = group_m, + SPLIT_K = split_k, + num_warps = num_warps, + num_stages = 1, + ACTIVATION=activation + ) + + else: + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M = block_m, + BLOCK_SIZE_N = block_n, + BLOCK_SIZE_K = block_k, + GROUP_SIZE_M = group_m, + num_warps = num_warps, + num_stages = 1, + ACTIVATION=activation + ) + + return c + + +def test_gemm(M, N, K, block_m, block_n, block_k, group_m, split_k, num_warps, dtype): + a = torch.randn((M, K), device='cuda', dtype=dtype) + b = torch.randn((K, N), device='cuda', dtype=dtype) + c = matmul(a, b, block_m, block_n, block_k, group_m, split_k, num_warps) + + return c + + +def main(args=None): + if args is None: + args = sys.argv[1:] + + parser = argparse.ArgumentParser( + prog="test gemm tuning", + description="Tuning infra for triton gemm", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=argparse.SUPPRESS) + parser.add_argument("-n", type=int, default=argparse.SUPPRESS) + parser.add_argument("-k", type=int, default=argparse.SUPPRESS) + parser.add_argument("-block_m", type=int, default=argparse.SUPPRESS) + parser.add_argument("-block_n", type=int, default=argparse.SUPPRESS) + parser.add_argument("-block_k", type=int, default=argparse.SUPPRESS) + parser.add_argument("-group_m", type=int, default=argparse.SUPPRESS) + parser.add_argument("-split_k", type=int, default=argparse.SUPPRESS) + parser.add_argument("-num_warps", type=int, default=argparse.SUPPRESS) + parser.add_argument("-dtype", type=str, default='fp16', help="Input/output data type") + parsed_args = parser.parse_args(args) + + dtype = torch.float16 + if parsed_args.dtype == 'fp16': + dtype = torch.float16 + elif parsed_args.dtype == 'fp32': + dtype = torch.float32 + elif parsed_args.dtype == 'bf16': + dtype = torch.bfloat16 + else: + print(f"Unsupported datatype {args.dtype}") + sys.exit(1) + + M = parsed_args.m + N = parsed_args.n + K = parsed_args.k + block_m = parsed_args.block_m + block_n = parsed_args.block_n + block_k = parsed_args.block_k + group_m = parsed_args.group_m + split_k = parsed_args.split_k + num_warps = parsed_args.num_warps + test_gemm(M, N, K, block_m, block_n, block_k, group_m, split_k, num_warps, dtype) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/tune_gemm/tune_gemm.py b/python/perf-kernels/tune_gemm/tune_gemm.py new file mode 100755 index 000000000000..3fdd7da082b5 --- /dev/null +++ b/python/perf-kernels/tune_gemm/tune_gemm.py @@ -0,0 +1,937 @@ +#!/usr/bin/env python3 + +import argparse +import sys +import yaml +import os +import glob + +import torch +import triton +import triton.language as tl + +from matmul_kernel import matmul_kernel + +from datetime import datetime +import multiprocessing +import pandas as pd + +from utils.file_generator import * +from utils.utils import * + + +def is_hip_available(): + try: + __import__("hip") + except ImportError: + return False + else: + return True + + +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for split_k in split_k_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + 'BLOCK_SIZE_M': + block_m, + 'BLOCK_SIZE_N': + block_n, + 'BLOCK_SIZE_K': + block_k, + 'GROUP_SIZE_M': + group_m, + 'SPLIT_K': + split_k, + 'num_warps': + num_warps, + 'num_stages': + num_stages, + 'waves_per_eu': + waves_per_eu, + 'matrix_instr_nonkdim': + matrix_instr_nonkdim, + 'kpack': + kpack + }) + + return configs + + +def get_default_config(): + full_configs = get_full_tuning_space() + return full_configs[0] + + +def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): + pruned_configs = [] + + if M < 32 or N < 32: + mfma = 16 + else: + mfma = 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + num_stages = config.get("num_stages") + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + kpack = config.get("kpack") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elemens per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim: + continue + if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim: + continue + if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16: + continue + if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0 and SPLIT_K != 1: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + LDS = LDS if not num_stages else LDS * num_stages + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + # check if tiling is integer multiple of GEMM size because we have no boundary check + if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def extract_kernel_time(M, N, K, config, df, bias_size): + # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19 + # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should + # not need below two lines + cols = [ + 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', + 'tid', 'grd', 'wgr', 'lds', 'scr', 'arch_vgpr', 'accum_vgpr', 'sgpr', + 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' + ] + df.columns = cols + configStr = gen_configStr(config) + filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy() + filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs'] + meanTime = filtered_df['DurationNs'].tail(100).mean() + return config, meanTime + + +def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose): + ngpus = len(gpus) + gpuIdx = gpus.index(gpuid) + if gpuIdx + 1 > jobs: + return + os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid) + jobId = gpuIdx + while jobId < jobs: + kernel_name = get_filename_profile_driver(M, N, K, jobId) + if verbose: + print(f"profiling {kernel_name} on GPU {gpuid}") + run_bash_command_wrapper( + f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}", + capture=(verbose < 2)) + jobId += ngpus + + +def tune_gemm_config(M, + N, + K, + col_a, + col_b, + dtype_a, + dtype_b, + dtype_c, + init_type, + configs, + run_bench, + jobs, + iters, + skipWarmup, + verbose=0, + num_threads=32, + gpus=[0], + rotating_buffer_size=256, + bias_size=0, + icache_flush=False): + + # precompile the kernels in parallel + start_time = datetime.now() + if not skipWarmup: + # Generate kernel out of all configs + fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, + dtype_b, dtype_c, init_type, configs, + rotating_buffer_size, bias_size) + + run_bash_command(f"python {fname} -n {num_threads}", + capture=(verbose < 2)) + compile_end = datetime.now() + compile_time = compile_end - start_time + if verbose: + print(f"compile time: {compile_time}", flush=True) + + # Generate kernels out of all configs + generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, jobs, iters, run_bench, + rotating_buffer_size, bias_size, icache_flush) + + # profile generated kernels + running = [ + multiprocessing.Process(target=profile_batch_kernels, + args=(M, N, K, gpu_id, gpus, jobs, verbose)) + for gpu_id in gpus + ] + for p in running: + p.start() + for p in running: + p.join() + + profile_end = datetime.now() + profile_time = profile_end - compile_end + if verbose: + print(f"profile time: {profile_time}", flush=True) + + # post process results.csv to get the best config and minTime + # TODO: process the file in parallel + minTime = 1024 * 1024 * 1024 + thread_pool = multiprocessing.Pool(processes=num_threads) + tasks = [] + idx = 0 + df_prof = [ + pd.read_csv(f"results_{i}.csv", + skiprows=1, + header=None, + delimiter=',', + quotechar='"', + escapechar='\\') for i in range(jobs) + ] + for config in configs: + file_idx = idx % jobs + tasks += [ + thread_pool.apply_async(extract_kernel_time, + args=(M, N, K, config, df_prof[file_idx], + bias_size)) + ] + idx += 1 + thread_pool.close() + thread_pool.join() + + for task in tasks: + config, myTime = task.get() + if myTime: + min_us = myTime / 1000 + if min_us < minTime: + minTime = min_us + bestConfig = config + else: + min_us = -1 + print( + f"invalid config(post processing): SIZE {M} {N} {K}: {config}", + flush=True) + post_end = datetime.now() + post_time = post_end - profile_end + if verbose: + print(f"post procesing time: {post_time}", flush=True) + return minTime, bestConfig, compile_time, profile_time, post_time + + +def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'): + d_type = name_to_tl_types[ty_name] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, + BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + def init_by_size_and_type(size, dtype, init_type): + if init_type == 'hpl': + return torch.empty(size, device='cuda', + dtype=dtype).uniform_(-0.5, 0.5) + # This init type has element[i] in row[j] equal to sin(i+j*N) + elif init_type == 'trig_float': + M, N = size + return torch.reshape(torch.arange(0, M * N), + (M, N)).sin().to(dtype=dtype, device='cuda') + elif init_type == 'zeros': + return torch.zeros(size, dtype=dtype, device='cuda') + elif init_type == "randn": + temp = torch.randn(size, dtype=dtype, device='cuda') + return temp + else: + raise ValueError("Bad matrix initialization type.") + + raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), + torch.float32, init_type) + if needTrans: + raw_data = raw_data.T + if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ + (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8(): + input = raw_data.to(tl_to_torch_types[d_type]) + input_f16 = input.to(torch.float16) + else: + f8_tensor = raw_data.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + n_elements = raw_data.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + + return input, input_f16 + + +# generate inputs/outputs according to rotating tensor size +def gen_rotating_tensors(M, + N, + K, + dtype_a, + need_Trans_a, + dtype_b, + need_Trans_b, + dtype_c, + seed, + init_type, + rotating_buffer_size, + bias_size, + device='cuda'): + a_size = M * K * type_name_to_bytes(dtype_a) + b_size = K * N * type_name_to_bytes(dtype_b) + c_size = M * N * type_name_to_bytes(dtype_c) + bias_size = bias_size * type_name_to_bytes(dtype_c) + + total_size = a_size + b_size + c_size + bias_size + block_count = rotating_buffer_size * 1024 * 1024 // total_size + block_count = max(1, block_count) + + # generate input and outputs + a = [] + b = [] + c = [] + bias = [] + for i in range(block_count): + in_a, in_a_fp16 = gen_input(M, + K, + dtype_a, + need_Trans_a, + 1, + init_type, + device='cuda') + a.append(in_a) + in_b, in_b_fp16 = gen_input(K, + N, + dtype_b, + need_Trans_b, + 2, + init_type, + device='cuda') + b.append(in_b) + out_c = torch.zeros((M, N), + dtype=tl_to_torch_types[name_to_tl_types[dtype_c]], + device='cuda') + c.append(out_c) + if bias_size > 0: + bs, bs_fp16 = gen_input(M, + 1, + dtype_b, + need_Trans_b, + 2, + init_type, + device='cuda') + bias.append(bs.squeeze()) + + in_outs = { + "rotating_num": block_count, + "input_a": a, + "input_b": b, + "output_c": c, + "bias": bias + } + + return in_outs + + +def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, + num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, + use_bias): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + #assert a.is_contiguous(), "Matrix A must be contiguous" + #assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # 1D launch kernel where each block gets its own program. + + grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k + stride_bias = bias.stride(0) if use_bias else 0 + EVEN_K = K % block_k == 0 + matmul_kernel[grid](a, + b, + c, + bias, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + stride_bias=stride_bias, + BLOCK_SIZE_M=block_m, + BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, + GROUP_SIZE_M=group_m, + SPLIT_K=split_k, + num_warps=num_warps, + num_stages=num_stages, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, + BIAS=use_bias, + EVEN_K=EVEN_K) + return c + + +def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, config, bias_vector, verbose): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) + use_bias = bias_vector + torch.manual_seed(0) + #a = torch.randn((M, K), device='cuda', dtype=datatype) + #b = torch.randn((K, N), device='cuda', dtype=datatype) + a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda') + b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda') + bias = None + if use_bias: + bias, bias_fp16 = gen_input(M, + 1, + dtype_b, + col_b, + 2, + init_type, + device='cuda') + bias = bias.squeeze() + bias_fp16 = bias.squeeze() + # Allocates output. + c = torch.zeros((M, N), + device=a.device, + dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) + triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, + split_k, num_warps, num_stages, waves_per_eu, + mfmaInstrSize, kpack, use_bias) + torch_output = torch.matmul(a_fp16, b_fp16) + if use_bias: + torch_output += bias_fp16[:, None] + rtol = 0 if torch.version.hip is None else 1e-2 + atol = 1e-3 if split_k == 1 else 4e-2 + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = '' + if verbose: + size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}' + if torch.allclose(triton_output.to(torch.float16), + torch_output, + atol=atol, + rtol=rtol): + print(f'{size_str} Correct✅') + else: + print(f"triton_output={triton_output}") + print(f"torch_output={torch_output}") + print(f'{size_str} Incorrect❌') + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=0) + parser.add_argument("-n", type=int, default=0) + parser.add_argument("-k", type=int, default=0) + parser.add_argument("-col_a", + action='store_true', + default=False, + help='whether matrix a is column major') + parser.add_argument("-col_b", + action='store_true', + default=False, + help='whether matrix b is column major') + parser.add_argument("-dtype_a", + type=str, + default='fp16', + help="matrix a element data type") + parser.add_argument("-dtype_b", + type=str, + default='fp16', + help="matrix b element data type") + parser.add_argument("-dtype_c", + type=str, + default='fp16', + help="output element data type") + parser.add_argument("--ngpus", + type=int, + default=0, + help='number of GPUs used in the profiling step') + parser.add_argument("--gpu_ids", + type=lambda s: [int(id) for id in s.split(',')], + default=[], + help='list of gpu ids to use for tuning') + parser.add_argument("--gemm_size_file", + type=str, + default="", + help='yaml file to indicate matrix size') + parser.add_argument("--o", + type=str, + default='', + help='yaml file to store tuning results') + parser.add_argument("--keep", + action='store_true', + default=False, + help='keep generated files') + parser.add_argument("--compare", + action='store_true', + default=False, + help="Whether check result correctness") + parser.add_argument( + "--compare_wo_tuning", + action='store_true', + default=False, + help="Whether check result correctness without tuning.") + parser.add_argument("--benchmark", + action='store_true', + default=False, + help="Benchmark the given config") + parser.add_argument( + "--time_breakdown", + action='store_true', + default=False, + help="Show detailed time breakdown of each step during the tuning") + parser.add_argument( + "--verbose", + action='store_true', + default=False, + help="enables time_breakdown and additional logging messages") + parser.add_argument( + "--num_threads", + type=int, + default=32, + help= + "number of threads to use for kernel compilation and post processing") + parser.add_argument("--jobs", + type=int, + default=1, + help="number of tasks during the profiling process") + parser.add_argument("--iters", + type=int, + default=1000, + help="number of iterations used in --benchmark mode") + parser.add_argument( + "--init_type", + type=str, + default='randn', + choices=['randn', 'hpl', 'trig_float', 'zeros'], + help="Input tensor initialization (default normal distribution)") + parser.add_argument( + "--rotating_tensor", + type=int, + default=0, + help="total size (MB) of all tensors (a, b, c, bias)." + " The default value is 0 (no rotating tensor)." + " When set, it needs to be larger than the L1, L2, MALL size)") + parser.add_argument("--bias_vector", + action='store_true', + default=False, + help="apply bias vector") + parser.add_argument("--icache_flush", + action='store_true', + default=False, + help="apply icache flush in tuning performance") + parser.add_argument("--no_warmup", + action='store_true', + default=False, + help="Whether we want to skip the compilation stage") + parser.add_argument("--hack_triton_compiler", + action='store_true', + default=False, + help="Modify the triton source to avoid backend query") + args = parser.parse_args() + if not args.o: + if args.benchmark: + args.o = "benchmarking_results.csv" + else: + args.o = get_default_tuning_result_filename() + + return args + + +def process_item(item): + M = item['M'] + N = item['N'] + K = item['K'] + col_a = False if item['rowMajorA'] == 'T' else True + col_b = False if item['rowMajorB'] == 'T' else True + del item['M'] + del item['N'] + del item['K'] + del item['rowMajorA'] + del item['rowMajorB'] + return M, N, K, col_a, col_b, item + + +def type_name_to_bytes(ty_name): + if '32' in ty_name: + return 4 + if '16' in ty_name: + return 2 + if '8' in ty_name: + return 1 + else: + print(f"Unrecognized input type name {ty_name}") + sys.exit(1) + + +def format_output(unformatted): + if unformatted < 0.0001: + formatted = "{:.3e}".format(unformatted) + elif unformatted > 1000: + formatted = "{:.1f}".format(unformatted) + else: + formatted = "{:.2f}".format(unformatted) + return formatted + + +def get_rocm_version(): + torch_hip_version = torch.version.hip + vers = torch_hip_version.split('.') + ret_ver = '$rocm_version' + if len(vers) >= 2: + ret_ver = vers[0] + '.' + vers[1] + return ret_ver + + +def main(): + args = parse_args() + matrix_size_file = args.gemm_size_file + output_file = args.o + keepTmp = args.keep + run_bench = args.benchmark + jobs = args.jobs + iters = args.iters + skipWarmup = args.no_warmup + hack_triton = args.hack_triton_compiler + + # Get GPU ids + ngpus = args.ngpus + gpu_ids = args.gpu_ids + if ngpus != 0 and gpu_ids: + print("--ngpus and --gpu_ids are mutually exclusive options") + return os.EX_USAGE + if ngpus == 0 and not gpu_ids: + ngpus = 1 + if ngpus != 0: + gpus = range(ngpus) + if gpu_ids: + gpus = gpu_ids + + if run_bench: + gpus = [gpus[0]] + jobs = 1 + + # Get element type + dtype_a = args.dtype_a + dtype_b = args.dtype_b + dtype_c = args.dtype_c + if not dtype_a in name_to_tl_types or not dtype_b in name_to_tl_types or not dtype_c in name_to_tl_types: + print( + f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}" + ) + print("Supported types: ", list(name_to_tl_types.keys())) + sys.exit(1) + rotating_buffer_size = args.rotating_tensor + bias_vector = args.bias_vector + icache_flush = args.icache_flush + if icache_flush: + if not is_hip_available(): + print("************************************************************************************************") + print(" `icache-flush` is disabled for this run.") + print(" `icache-flush` needs python-hip module, which is unavailable.") + print(" python-hip module can be installed as:") + print(f" `python3 -m pip install -i https://test.pypi.org/simple hip-python~={get_rocm_version()}`") + print("************************************************************************************************") + icache_flush = False + + mnks = [] + # TODO: make it more robust to get user input + init_type = args.init_type + if matrix_size_file == "" or not os.path.isfile(matrix_size_file): + M = args.m + N = args.n + K = args.k + col_a = args.col_a + col_b = args.col_b + mnks = [(M, N, K, col_a, col_b, None)] + else: + with open(matrix_size_file) as file: + matrix_sizes = yaml.safe_load(file) + for item in matrix_sizes: + M, N, K, col_a, col_b, item = process_item(item) + mnks.append((M, N, K, col_a, col_b, item)) + + # Check correctness from given configs + if args.compare_wo_tuning: + for (M, N, K, col_a, col_b, myConfig) in mnks: + if myConfig is None: + raise Exception( + "kernel config is None, need to provide a tuning config") + test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, myConfig, bias_vector, True) + return + + configs_full = get_full_tuning_space() + + start_time = datetime.now() + # Append to the output file so that we can save all results into one file + f_results = open(output_file, 'a') + if run_bench: + print(f"Benchmarking gemm with {dtype_a} inputs") + print("trans M N K TFLOPS us") + f_results.write("trans,M,N,K,TFLOPS,us\n") + else: + print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", + flush=True) + + f_results.close() + + ## Before tuning starts, clear cache and previously generated kernel files + run_bash_command("rm -rf ~/.triton/cache") + run_bash_command(f"rm -rf {get_filename_myKernels()}") + + ## Modify triton compiler + ## Hacky !!! + if hack_triton: + patch_triton_compiler() + + configs = [] + + ## Big for loop of tuning + ## Each iteration performs tuning for one gemm size + for (M, N, K, col_a, col_b, myConfig) in mnks: + + f_results = open(output_file, 'a') + + start_local_time = datetime.now() + # Obtain a pruned tuning space according to gemm size + # If running benchmark, use the provided config + pruned_configs = [myConfig] if run_bench else prune_configs( + M, N, K, configs_full, type_name_to_bytes(dtype_a), + type_name_to_bytes(dtype_b)) + + ## Only append new configs from the current gemm size + delta_configs = [ + config for config in pruned_configs if config not in configs + ] + configs += delta_configs + + ## Append new configs into the tuning space + generate_matmul_kernels(delta_configs) + + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}' + if not run_bench: + print(f"{size_str} nConfigs: {len(pruned_configs)}", + end=" ", + flush=True) + else: + print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", + end="") + f_results.write(f"{row_a_str}{row_b_str},{M},{N},{K},") + + # The main tuning funtion for one gemm size + verbose_level = 0 + if args.time_breakdown: + verbose_level = 1 + if args.verbose: + verbose_level = 2 + # we consider bias size as M for now. + bias_size = M if bias_vector else 0 + minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config( + M, + N, + K, + col_a, + col_b, + dtype_a, + dtype_b, + dtype_c, + init_type, + pruned_configs, + run_bench, + jobs, + iters, + skipWarmup, + num_threads=args.num_threads, + gpus=gpus, + verbose=verbose_level, + rotating_buffer_size=rotating_buffer_size, + bias_size=bias_size, + icache_flush=icache_flush) + + # post processing the numbers + perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6) + tri_tflops = perf_tflops(minTime) + formatted_tflops = format_output(tri_tflops) + minTime = format_output(minTime) + if not run_bench: + print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', + end=" ", + flush=True) + + bestConfig_compact_str = gen_configStr(bestConfig) + if not run_bench: + print(f'best_config: {bestConfig_compact_str}', + end=" ", + flush=True) + + # write best config to tuning_results.yaml + if run_bench: + print(f"{formatted_tflops} {minTime}") + f_results.write(f"{formatted_tflops},{minTime}\n") + + sizeDict = { + 'M': M, + 'N': N, + 'K': K, + 'rowMajorA': row_a_str, + 'rowMajorB': row_b_str + } + sizeDict.update(bestConfig) + if not run_bench: + f_results.write("- " + str(sizeDict) + " ") + f_results.write( + f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') + + # remove generated files if asked to + if not keepTmp: + if not skipWarmup: + os.remove(get_filename_compile_driver()) + try: + os.remove(get_filename_compile_driver() + + ".failed_configs") + except OSError: + pass + for i in range(jobs): + generated_script = get_filename_profile_driver(M, N, K, i) + os.remove(generated_script) + for f in glob.glob(f"results_{i}.*"): + os.remove(f) + + # Check correctness if asked to + if args.compare: + print("correctness: ", end=" ", flush=True) + test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, bestConfig, bias_vector, False) + elif not run_bench: + print("", flush=True) + + end_local_time = datetime.now() + if not run_bench: + print( + f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", + flush=True) + + f_results.close() + ## End big loop for tuning + + end_time = datetime.now() + tuning_time = end_time - start_time + if not run_bench: + print(f"Tuning ends at: {end_time}") + print(f"Total tuning time (h:m:s): {tuning_time}") + + if hack_triton: + print( + "Triton compiler is hacked, don't forget to git restore the changes :)" + ) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/tune_gemm/tune_gemm.sh b/python/perf-kernels/tune_gemm/tune_gemm.sh new file mode 100755 index 000000000000..b49b57aa0aa3 --- /dev/null +++ b/python/perf-kernels/tune_gemm/tune_gemm.sh @@ -0,0 +1,27 @@ +#! /bin/bash + +## $1: driver program +## $2: M +## $3: N +## $4: K +## $5: 1: reduced tuning space + +if [[ $# -lt 4 ]];then + echo "Usage: ./tune_gemm.sh M N K" + exit +fi + +DRIVER=$1 +M=$2 +N=$3 +K=$4 +reduceSpace=$5 + +DRIVER=$(echo $DRIVER | sed -e "s/matmul_grouped.py/matmul.py/g") + +# $DRIVER is the actual tuning scripts, it is the file matmul.py +# -mnk are the size of input matrices, matrix (m, k) x (k, n) +# --specify_size means using -mnk to specify size of input matrices +# --rocprof means using rocprof to measure kernel time. If not set, +# kernel time is from do_bench() +python $DRIVER -m $M -n $N -k $K --specify_size --rocprof diff --git a/python/perf-kernels/tune_gemm/utils/file_generator.py b/python/perf-kernels/tune_gemm/utils/file_generator.py new file mode 100644 index 000000000000..eea92cf6bf48 --- /dev/null +++ b/python/perf-kernels/tune_gemm/utils/file_generator.py @@ -0,0 +1,355 @@ +import os +from .utils import * + + +def read_config(config): + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + block_k = config.get('BLOCK_SIZE_K') + group_m = config.get('GROUP_SIZE_M') + split_k = config.get('SPLIT_K') + num_warps = config.get('num_warps') + num_stages = config.get('num_stages') + waves_per_eu = config.get('waves_per_eu') + mfma_instr_size = config.get('matrix_instr_nonkdim') + kpack = config.get('kpack') + return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack + + +def gen_configStr(config): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) + + ## {M}_{N}_{K} is removed since the same kernel can be used for differen gemm sizes + configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" + + return configStr + + +def generate_matmul_kernels(configs): + """ + Generate kernels based on configs and append them to get_filename_myKernels() + + Use the matmul_kernel template (../matmul_kernel.py) and append config to the + kernel name. E.g. matmul_kernel_BM256_BN256_BK64_GM1_SK1_nW1_nS0_EU0_kP2_mfma16() + """ + + if len(configs) == 0: + return + + f_kernel = open(get_filename_myKernels(), 'a') + + # write imports + import_str = """import triton +import triton.language as tl""" + f_kernel.write(import_str) + + with open( + os.path.dirname(os.path.abspath(__file__)) + + "/../matmul_kernel.py") as file: + matmul_kernel_code = file.read() + + for config in configs: + configStr = gen_configStr(config) + # Copy the matmul_kernel with name replaced + matmul_kernel_config = matmul_kernel_code.replace( + "matmul_kernel", f"matmul_kernel_{configStr}") + matmul_kernel_config = matmul_kernel_config.replace( + "import triton.language as tl", "") + matmul_kernel_config = matmul_kernel_config.replace( + "import triton", "") + f_kernel.write(matmul_kernel_config) + + f_kernel.close() + + +## construct the configStr and generate the wrapper function matmul_{configStr}() +## If `warmup` is set, the generated kernel will be **compiled** +def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, + dtype_c, bias_size, warmup): + block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( + config) + + configStr = gen_configStr(config) + + use_bias = bias_size > 0 + + if warmup: + torch_dtype_a = 'fp16' + torch_dtype_b = 'fp16' + torch_dtype_c = 'fp16' + if dtype_a: + torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] + if dtype_b: + torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] + if dtype_c: + torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] + + matmul_def_str = f""" +def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): + matmul_kernel_{configStr}.warmup( + {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c}, + M, N, K, + am, ak, bk, bn, cm, cn, biasn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + SPLIT_K = {split_k}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + BIAS={use_bias}, + EVEN_K={EVEN_K}, + grid=(1,), + ) + return None + +def try_compile_config_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): + try: + matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn) + return True + except Exception as e: + print(f'invalid config(compilation): {configStr}: ', e, flush=True) + return False +""" + else: + matmul_def_str = f""" +def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn): + grid = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}), {split_k} + matmul_kernel_{configStr}[grid]( + a, b, c, bias, + M, N, K, + am, ak, bk, bn, cm, cn, biasn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + SPLIT_K = {split_k}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + BIAS = {use_bias}, + EVEN_K = {EVEN_K} + ) + return c +""" + return configStr, matmul_def_str + + +def generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, rotating_buffer_size, + bias_size): + """ + Generate a single file that contains all kernels in the tuning space. + This file is used to **compile** the kernels in parallel + """ + + filename = get_filename_compile_driver() + f_kernel = open(filename, 'w') + + # write imports + import_str = f"""import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_gemm import gen_rotating_tensors +from {get_filename_without_extension(get_filename_myKernels())} import * +""" + + f_kernel.write(import_str + "\n") + + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config( + config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, True) + # Copy the matmul_kernel with name replaced + f_kernel.write(matmul_def_str + "\n") + + # write compile_kernels + # pre string + stride_a_str = "1, M" if col_a else "M, 1" + stride_b_str = "1, N" if col_b else "N, 1" + stride_c_str = "N, 1" + compile_kernels_pre_str = f"""def compile_kernels(M, N, K, rotating_buffer_size, bias_size, num_threads): + thread_pool = multiprocessing.Pool(processes=num_threads) + + assert bias_size == M or bias_size == 0 + + stride_bias = 1 if bias_size > 0 else 0 + stride_am, stride_ak = {stride_a_str} + stride_bk, stride_bn = {stride_b_str} + stride_cm, stride_cn = {stride_c_str} + task_args = (M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, stride_bias) + + results = [] + config_names = [] +""" + f_kernel.write(compile_kernels_pre_str + "\n") + + # warm up call of all matmul functions in parallel + for config in configs: + configStr = gen_configStr(config) + task_str = f" results += [thread_pool.apply_async(try_compile_config_{configStr}, args=task_args)]\n" + \ + f" config_names += ['{configStr}']\n" + f_kernel.write(task_str) + + threadpool_str = """ + failed_configs = [] + for i in range(len(results)): + results[i].wait() + res = results[i].get() + if not res: + failed_configs += [config_names[i]] + thread_pool.close() + thread_pool.join() + if failed_configs: + with open("{filename}.failed_configs", "w") as f: + for cfg in failed_configs: + f.write(cfg + "\\n") +""".format(filename=filename) + f_kernel.write(threadpool_str) + + # def main and call compile_kernels + def_main_str = f""" +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=32, help='number of threads') + parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') + args = parser.parse_args() + numThreads = args.n + rotating_buffer_size = args.rotating_tensor + """ + compile_kernels_call_str = f'compile_kernels({M}, {N}, {K}, rotating_buffer_size, {bias_size}, numThreads)' + + f_kernel.write(def_main_str) + f_kernel.write(compile_kernels_call_str + "\n\n") + f_kernel.write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel.close() + + return filename + + +def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, + init_type, configs, jobs, iters, run_bench, + rotating_buffer_size, bias_size, icache_flush): + """ + Open {len(jobs)} files + generated_kernelM-N-K-0.py, generated_kernelM-N-K-1.py, ..., generated_kernelM-N-K-{njobs-1}.py + and generate + 1. matmul kernels of all configs + 2. wrapper function matmul to invoke all the generated kernels + 3. test_gemm to invoke matmul in a loop of {iters} iterations + """ + + filenames = [] + for i in range(jobs): + filenames.append(get_filename_profile_driver(M, N, K, i)) + f_kernel = [open(path, 'w') for path in filenames] + + # write imports + import_str = f"""import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_gemm import gen_rotating_tensors +from {get_filename_without_extension(get_filename_myKernels())} import * +""" + if icache_flush: + import_str += """ +from icache_flush import icache_flush +""" + for fi in range(jobs): + f_kernel[fi].write(import_str + "\n") + + idx = 0 + for config in configs: + file_idx = idx % jobs + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config( + config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, False) + # Copy the matmul_kernel with name replaced + f_kernel[file_idx].write(matmul_def_str + "\n") + idx += 1 + + # write test_gemm + # pre string + test_gemm_pre_str = f"""def test_gemm(M, N, K, rotating_buffer_size, bias_size): + tensors = gen_rotating_tensors(M, N, K, '{dtype_a}', {col_a}, '{dtype_b}', {col_b}, '{dtype_c}', + 1, '{init_type}', rotating_buffer_size, bias_size, device='cuda') + + a = tensors['input_a'][0] + b = tensors['input_b'][0] + c = tensors['output_c'][0] + assert bias_size == M or bias_size == 0 + + stride_bias = tensors['bias'][0].stride(0) if bias_size > 0 else 0 + + try: + with open("{get_filename_compile_driver()}.failed_configs", "r") as f: + failed_configs = [cfg.strip() for cfg in f.readlines()] + except Exception: + failed_configs = [] +""" + for fi in range(jobs): + f_kernel[fi].write(test_gemm_pre_str + "\n") + + # call all matmul_xxx functions + idx = 0 + runs = iters if run_bench else 200 + call_icache_flush = 'icache_flush()' if icache_flush else '' + for config in configs: + configStr = gen_configStr(config) + matmul_call_str = f""" + if '{configStr}' not in failed_configs: + rotating_num = tensors['rotating_num'] + for i in range({runs}): + a = tensors['input_a'][i % rotating_num] + b = tensors['input_b'][i % rotating_num] + c = tensors['output_c'][i % rotating_num] + bias = tensors['bias'][i % rotating_num] if bias_size > 0 else None + bias_stride = bias.stride(0) if bias_size > 0 else 0""" + if icache_flush: + matmul_call_str += f""" + icache_flush()""" + matmul_call_str += f""" + d = matmul_{configStr}(a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), bias_stride)""" + f_kernel[idx % jobs].write(matmul_call_str + "\n") + idx += 1 + # post string + for fi in range(jobs): + f_kernel[fi].write(" return d\n") + + # def main and call test_gemm + def_main_str = f""" +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=1, help='number of threads') + parser.add_argument("-rotating_tensor", type=int, default={rotating_buffer_size}, help='size of rotating buffer (MB), default: {rotating_buffer_size}') + args = parser.parse_args() + numThreads = args.n + rotating_buffer_size = args.rotating_tensor + """ + test_gemm_call_str = f'test_gemm({M}, {N}, {K}, rotating_buffer_size, {bias_size})' + for fi in range(jobs): + f_kernel[fi].write(def_main_str) + f_kernel[fi].write(test_gemm_call_str + "\n\n") + f_kernel[fi].write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel[fi].close() diff --git a/python/perf-kernels/tune_gemm/utils/utils.py b/python/perf-kernels/tune_gemm/utils/utils.py new file mode 100644 index 000000000000..9b6b50ea626b --- /dev/null +++ b/python/perf-kernels/tune_gemm/utils/utils.py @@ -0,0 +1,115 @@ +import torch +import triton +import triton.language as tl + +import os +import subprocess +from datetime import datetime + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8': tl.float8e4b8, + 'bf8': tl.float8e5b16, +} + + +def run_bash_command_wrapper(commandstring, capture=True): + try: + run_bash_command(commandstring, capture) + except subprocess.CalledProcessError as e: + if not capture: + print(f"running {commandstring} one more time") + run_bash_command(commandstring, capture) + + +def run_bash_command(commandstring, capture=True): + if capture: + proc = subprocess.run(commandstring, + shell=True, + check=True, + executable='/bin/bash', + stdout=subprocess.PIPE) + return proc.stdout.splitlines() + proc = subprocess.run(commandstring, + shell=True, + check=True, + executable='/bin/bash') + return None + + +def get_filename_myKernels(): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../myKernels.py" + + +def get_filename_without_extension(file_path): + base_name = os.path.basename(file_path) + file_name, _ = os.path.splitext(base_name) + return file_name + + +def get_filename_compile_driver(): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../compile_driver.py" + + +def get_filename_profile_driver(M, N, K, job_id): + path = os.path.dirname(os.path.abspath(__file__)) + return f"{path}/../profile_driver_{M}x{N}x{K}_{job_id}.py" + + +def get_default_tuning_result_filename(): + git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") + git_branch_name = git_branch_name[0].decode() + # handle branch name of "xxx/xxx" format + git_branch_name = git_branch_name.replace('/', '_') + git_commit_hash = run_bash_command("git rev-parse --short HEAD") + git_commit_hash = git_commit_hash[0].decode() + + dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + + path = os.path.dirname(os.path.abspath(__file__)) + defaultName = f"{path}/../tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" + return defaultName + + +def patch_triton_compiler(): + device = triton.runtime.driver.active.get_current_device() + stream = triton.runtime.driver.active.get_current_stream(device) + target = triton.runtime.driver.active.get_current_target() + + triton_location_str = run_bash_command("pip show triton | grep Editable") + if not triton_location_str: + print("triton source not found from pip show triton") + + triton_dir = triton_location_str[0].split()[-1].decode('utf-8') + + jit_filename = os.path.join(triton_dir, "triton/runtime", "jit.py") + + run_bash_command(f"sed -i 's/driver.active.get_current_device()/{device}/g' {jit_filename}") + run_bash_command(f"sed -i 's/driver.active.get_current_stream(device)/{stream}/g' {jit_filename}") + + hip_driver_filename = os.path.join(triton_dir, "../third_party/amd/backend/", "driver.py") + cuda_driver_filename = os.path.join(triton_dir, "../third_party/nvidia/backend/", "driver.py") + + run_bash_command(f"sed -i 's/import torch/return True/g' {hip_driver_filename}") + run_bash_command(f"sed -i 's/device = self.get_current_device()/return GPUTarget(\"hip\", \"{target.arch}\", 64)/g' {hip_driver_filename}") + run_bash_command(f"sed -i 's/import torch/return False/g' {cuda_driver_filename}") From 624335ff569562d5db26bea337e3c6de2bd6b0dc Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 16 Aug 2024 15:42:27 -0300 Subject: [PATCH 11/20] Clean up *tune_gemm* script from `main_perf` branch (#629) * Reformat *tune_gemm* files with Triton's pre-commit The following command was executed to reformat the files: ``` $ pre-commit run --files \ python/perf-kernels/tune_gemm/* \ python/perf-kernels/tune_gemm/utils/* ``` * Fix *tune_gemm* issue with (1, 1) bias tensors * Fix `ruff` F405 errors Fix the following linter error: F405 `identifier` may be undefined, or defined from star imports * Fix `ruff` F841 errors Fix the following linter error: F841 Local variable `identifier` is assigned to but never used * Fix minor issues in README file * Add `--` to `num_threads` argument. * Replace `--icahe` argument (non-existent argument) with `--icache_flush` (existent argument). * Remove old files from *tune_gemm* V1 * Add dependency graph to README file * Selectively disable `yapf` for parts of `one_config.py` --- python/perf-kernels/tune_gemm/README.md | 50 +- python/perf-kernels/tune_gemm/icache_flush.py | 28 +- python/perf-kernels/tune_gemm/matmul.py | 375 --------------- .../perf-kernels/tune_gemm/matmul_kernel.py | 15 +- python/perf-kernels/tune_gemm/one_config.py | 74 +-- python/perf-kernels/tune_gemm/rocprof_gemm.py | 318 ------------- python/perf-kernels/tune_gemm/tune_gemm.py | 447 +++++------------- python/perf-kernels/tune_gemm/tune_gemm.sh | 27 -- .../tune_gemm/utils/file_generator.py | 43 +- python/perf-kernels/tune_gemm/utils/utils.py | 17 +- 10 files changed, 229 insertions(+), 1165 deletions(-) delete mode 100644 python/perf-kernels/tune_gemm/matmul.py delete mode 100755 python/perf-kernels/tune_gemm/rocprof_gemm.py delete mode 100755 python/perf-kernels/tune_gemm/tune_gemm.sh diff --git a/python/perf-kernels/tune_gemm/README.md b/python/perf-kernels/tune_gemm/README.md index 5a986a9f987d..da45dcda5c3c 100644 --- a/python/perf-kernels/tune_gemm/README.md +++ b/python/perf-kernels/tune_gemm/README.md @@ -12,15 +12,15 @@ This means `BLOCK_SIZE_K` does not need to divide K dim. ### Differences between the tutorial -Unlike the [matmul tutorial](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py) (referred as the tutorial), +Unlike the [matmul tutorial](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py) (referred as the tutorial), the matmul kernel used in the tuning script (referred as the kernel) does not -guard load along M and N dim +guard load along M and N dim ([this](https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py#L282-L283) shows how this is done in the tutorial). -When `BLOCK_SIZE_M` or `BLOCK_SIZE_N` does not divide M or N, the kernel will +When `BLOCK_SIZE_M` or `BLOCK_SIZE_N` does not divide M or N, the kernel will load out-of-bound data. In most cases this is fine, since the kernel does masked store at the end. However, this may lead to GPU memory access fault in some cases, especially -when the tensor is large. +when the tensor is large. We will fix this issue in the future. @@ -45,7 +45,7 @@ The following `options` are supported in the tuning mode - `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: input and output element type. - Supported `dtype`: fp16 (default), bf16, fp8, bf8, int8, int32, fp32 - Parallel compilation of kernels: - - `num_threads n` controls that n threads will + - `--num_threads n` controls that n threads will be used in the compilation stage. The default value is 32. - `--no_warmup` can be used to skip the compilation stage. Thus kernels will be compiled during the profiling stage. This increases tuning time. But it's @@ -53,7 +53,7 @@ The following `options` are supported in the tuning mode kernel launch is not supported. - Parallel profiling of kernels: The tuning space is first divided into a number of tasks, which is controlled by `--jobs n`. And all the tasks can be profiled in -parallel on a number of GPUs in the system. There are two ways to specify which +parallel on a number of GPUs in the system. There are two ways to specify which GPU(s) we want to use for profiling. Note that these flags cannot be use together. By default, only one task is generated and profiled on GPU0. - `--ngpus n`: GPU 0,1,.., n-1 will be used. @@ -136,9 +136,9 @@ The supported `options` are as followings - `-dtype_a dtype`, `-dtype_b dtype`, and `-dtype_c dtype`: same as tuning mode. - `--iters n` controls the number of iterations to run the kernel. The default value is 1000. -- `--icahe`: same as tuning mode +- `--icache_flush`: same as tuning mode - `--rotating_tensor SIZE`: same as tuning mode - + ## Tuning script implementation overview @@ -150,6 +150,22 @@ task. This will save invocation overhead of the profiler. For detailed implementation, please refer to the changelog of each version. +### Dependency graph + +The following graph depicts the dependency between Python modules: +```mermaid +graph TD; + one_config.py --> tune_gemm.py + tune_gemm.py --> matmul_kernel.py + tune_gemm.py --> utils/file_generator.py + tune_gemm.py --> utils/utils.py + utils/file_generator.py --> utils/utils.py + utils/file_generator.py -.-> icache_flush.py +``` + +`utils/file_generator.py` doesn't import `icache_flush.py` but it generates kernels that can import +`icache_flush.py`. + # Changelog @@ -178,7 +194,7 @@ Workflow of the tuning process 1. Generate the full tuning space. For now the `range`s for each tuning parameter are hard-coded 2. Prune the tuning space according to the current GEMM size and some rules - BLOCK_SIZE must be equal or larger than the mfma instruction size. - - SPLIT_K * BLOCK_SIZE_K must divide K. Therefore, we do not need EVEN_K in the kernel. + - SPLIT_K * BLOCK_SIZE_K must divide K. Therefore, we do not need EVEN_K in the kernel. - When split-k is not needed, i.e. both M and N are large, it must be 1 - GROUP_M * BLOCK_SIZE_M must be smaller than M. Otherwise, GROUP_M must be 1 - When BLOCK_SIZE_K = 128, neither BLOCK_SIZE_M or BLOCK_SIZE_N can be 128. Otherwise too much LDS will be required. **Needs further investigation** @@ -188,7 +204,7 @@ Workflow of the tuning process 2. Generate `matmul` function for each config in a similar way 3. Generate `try_config` functions for each `matmul` function. 4. Generate `test_gemm`, which does - 1. Add all `try_config` functions in the thread_pool by `thread_pool.apply_async(try_config)`. This is used to compile all kernels in parallel. + 1. Add all `try_config` functions in the thread_pool by `thread_pool.apply_async(try_config)`. This is used to compile all kernels in parallel. 2. Call each `matmul` function in a for loop of 10 iterations 5. Generate `main` function 4. Run the generated script with 16 workers. This will compile all kernels in parallel. @@ -203,7 +219,7 @@ Workflow of the tuning process The provided types must be one of ['fp32', 'fp16', 'bf16', 'fp8', 'bf8', 'int8']. - Row/col major-ness of operand a and b can be provided as `-col_a` and `-col_b`. If set, it means the corresponding operand is column major. -The major-ness is considered as problem input. +The major-ness is considered as problem input. So they should be included in the input yaml file. However, in the yaml file, user should set `rowMajowA` and `rowMajorB` as shown in the example below. - `--benchmark` is used to control if the perf config in the input yaml file is used as the tuning space. @@ -218,7 +234,7 @@ This is necessary to keep each file "small" in terms of execution time. - In benchmark mode, the kernel is executed 1000 times. - In tuning mode, each kernel is executed 200 times. We cannot afford to larger runs since rocprof hangs if the session takes too long. - In both tuning and benchmark mode, kernel time is measured as the average execution time of the last 100 instances. -- Added error recovery. This helps when rocprof crashes in multi-processing mode. +- Added error recovery. This helps when rocprof crashes in multi-processing mode. @@ -233,12 +249,12 @@ This is necessary to keep each file "small" in terms of execution time. ### API changes -- Added `--rotating_tensor ` to use rotating memory blocks in each iteration, size in MB. Default is 0MB. -- Added `--icache_flush` to flush icache in each iteration. +- Added `--rotating_tensor ` to use rotating memory blocks in each iteration, size in MB. Default is 0MB. +- Added `--icache_flush` to flush icache in each iteration. Note, icache flush needs the module `python-hip`, which can be installed as: `python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version` Rotating tensor and icache flush are to make perf numbers are closer to that in real applications. -- Added `--bias_vector` to support kernel execution with bias (bias vector is of the same size as the number of rows of the output matrix, +- Added `--bias_vector` to support kernel execution with bias (bias vector is of the same size as the number of rows of the output matrix, so each element of the bias vector is added to all elements of the corresponding row of the output matrix.) @@ -283,11 +299,11 @@ that cannot divide `K`. - Tuning result file is open and closed inside the tuning loop, enabling timely flush of the tuning results. - Now we use `rocprofv2` to measure kernel time. -- We can use `--hack_triton_compile` to avoid all GPU activities during the compilation +- We can use `--hack_triton_compile` to avoid all GPU activities during the compilation stage. This is achieved by modifying the triton frontend compiler in the following places: - Return True from the `is_active()` function in the hip hackend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L433) - - Return statically constructed GPUTarget from the `get_current_target()` + - Return statically constructed GPUTarget from the `get_current_target()` function in the hip backend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/amd/backend/driver.py#L437) - Return False from the `is_active()` function in the cuda hackend [driver](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/third_party/nvidia/backend/driver.py#L383) - Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589) diff --git a/python/perf-kernels/tune_gemm/icache_flush.py b/python/perf-kernels/tune_gemm/icache_flush.py index 320e746d30d4..6b9c0359c381 100644 --- a/python/perf-kernels/tune_gemm/icache_flush.py +++ b/python/perf-kernels/tune_gemm/icache_flush.py @@ -1,13 +1,9 @@ -import ctypes -import array -import random -import math - # the hip module can be installed as # `python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version` # more information about hip-python is at: https://github.com/ROCm/hip-python from hip import hip, hiprtc + def hip_check(call_result): err = call_result[0] result = call_result[1:] @@ -16,14 +12,12 @@ def hip_check(call_result): if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: raise RuntimeError(str(err)) - elif ( - isinstance(err, hiprtc.hiprtcResult) - and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS - ): + elif (isinstance(err, hiprtc.hiprtcResult) and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS): raise RuntimeError(str(err)) return result + # S_ICACHE_INV Invalidate entire first level instruction cache. # There must be 16 separate S_NOP instructions or a jump/branch instruction # after this instruction to ensure the internal instruction buffers are also @@ -56,7 +50,7 @@ def gen_kernel(): progs = hip.hipDeviceProp_t() hip_check(hip.hipGetDeviceProperties(progs, 0)) arch = progs.gcnArchName - cflags = [b"--offload-arch="+arch] + cflags = [b"--offload-arch=" + arch] err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags) if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS: log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog)) @@ -73,22 +67,16 @@ def gen_kernel(): return kernel + kernel = gen_kernel() progs = hip.hipDeviceProp_t() hip_check(hip.hipGetDeviceProperties(progs, 0)) cu_num = progs.multiProcessorCount + def icache_flush(): block = hip.dim3(x=64) grid = hip.dim3(cu_num * 60) - hip_check(hip.hipModuleLaunchKernel( - kernel, - *grid, - *block, - sharedMemBytes=0, - stream=None, - kernelParams=None, - extra=() - ) - ) + hip_check( + hip.hipModuleLaunchKernel(kernel, *grid, *block, sharedMemBytes=0, stream=None, kernelParams=None, extra=())) diff --git a/python/perf-kernels/tune_gemm/matmul.py b/python/perf-kernels/tune_gemm/matmul.py deleted file mode 100644 index 5b39d9330bff..000000000000 --- a/python/perf-kernels/tune_gemm/matmul.py +++ /dev/null @@ -1,375 +0,0 @@ -""" -Matrix Multiplication Tuning Scripts, Changed from the tutorial example "python/tutorials/03-matrix-multiplication.py" -""" - -import torch - -import triton -import triton.language as tl -import argparse -import sys -import yaml -import os -import subprocess - - - -# global flag to indicate whether using the full tuing space -tuning_full_space = True - -# pruned some unreasonable config -def prune_configs(configs, named_args): - # call only for full tuning space - if not tuning_full_space: - return configs - - SIZE_M = named_args["a_ptr"].shape[0] - SIZE_N = named_args["b_ptr"].shape[1] - SIZE_K = named_args["a_ptr"].shape[1] - - pruned_configs = [] - for config in configs: - kw = config.kwargs - BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K =\ - kw["BLOCK_SIZE_M"], kw["BLOCK_SIZE_N"], kw["BLOCK_SIZE_K"] - SPLIT_K = kw["SPLIT_K"] - if SIZE_M <=32 and BLOCK_SIZE_M != 32: - continue - if SIZE_N <=32 and BLOCK_SIZE_N != 32: - continue - # skip large split_k when not necessary - if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K): - continue - pruned_configs.append(config) - - return pruned_configs - - -def get_full_tuning_space(use_split_k): - configs = [] - if not tuning_full_space: - return configs - - block_mn_range = [32, 64, 128] - block_k_range = [32, 64] - split_k_range = [1, 2, 4, 5, 8, 10] - num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8] - # For now we see better perf with num_stages=0 for all gemm configs we care - # But keep this explicit so that we do not forget we may need to set it to - # other values in the future - num_stage_range = [0] - - for block_m in block_mn_range: - for block_n in block_mn_range: - for block_k in block_k_range: - for num_warps in num_warps_range: - for group_m in group_m_range: - for split_k in split_k_range: - for num_stages in num_stage_range: - configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k}, num_stages=num_stages, num_warps=num_warps)) - - return configs - - -# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: -# - A list of `triton.Config` objects that define different configurations of -# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try -# - An auto-tuning *key* whose change in values will trigger evaluation of all the -# provided configs -@triton.autotune( - configs= get_full_tuning_space(True) if tuning_full_space else [ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'SPLIT_K': 1}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 8}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 10}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 8}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 10}, num_stages=1, num_warps=1), - ], - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': prune_configs, - 'perf_model': None, - "top_k": None - }, -) -@triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_SIZE_K'] * args['SPLIT_K']) == 0, -}) -@triton.jit -def matmul_kernel_splitK( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ACTIVATION: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - # See above `L2 Cache Optimizations` section for details. - pid = tl.program_id(axis=0) - pid_z = tl.program_id(1) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetics` section for details - if SPLIT_K == 1: - offs_k = tl.arange(0, BLOCK_SIZE_K) - else: - offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - if torch.version.hip is None: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - else: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - if EVEN_K: - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - else: - k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) - a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - if ACTIVATION == "leaky_relu": - accumulator = leaky_relu(accumulator) - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if SPLIT_K == 1: - tl.store(c_ptrs, c, mask=c_mask) - else: - tl.atomic_add(c_ptrs, c, mask=c_mask) - - -# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. -@triton.jit -def leaky_relu(x): - x = x + 1 - return tl.where(x >= 0, x, 0.01 * x) - - -def need_split_k(SIZE_M, SIZE_N, SIZE_K): - return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 - - -def matmul(a, b, activation=""): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.is_contiguous(), "Matrix A must be contiguous" - assert b.is_contiguous(), "Matrix B must be contiguous" - M, K = a.shape - K, N = b.shape - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - # 1D launch kernel where each block gets its own program. - - grid_splitK = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'] - ) - matmul_kernel_splitK[grid_splitK]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ACTIVATION=activation - ) - - return c - - -def test_correctness(M, N, K, datatype = torch.float16): - torch.manual_seed(0) - a = torch.randn((M, K), device='cuda', dtype=datatype) - b = torch.randn((K, N), device='cuda', dtype=datatype) - triton_output = matmul(a, b) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - rtol = 0 if torch.version.hip is None else 1e-2 - size_str = f'size, (M: {M}, N: {N}, K: {K})' - if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): - print(f'✅ Triton and Torch match for {size_str}') - else: - print(f'❌ Triton and Torch differ for {size_str}') - - -def run_speed(M, N, K, datatype, use_rocprof, provider): - a = torch.randn((M, K), device='cuda', dtype=datatype) - b = torch.randn((K, N), device='cuda', dtype=datatype) - quantiles = [0.5, 0.2, 0.8] - if provider == 'pytorch': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) - return min_ms - -def run_bash_command(commandstring): - #print( commandstring ) - proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout = subprocess.PIPE) - return proc.stdout.splitlines() - - -def parse_args(): - parser = argparse.ArgumentParser( - prog="tune a specific gemm size", - allow_abbrev=False, - ) - - parser.add_argument("-m", type=int, default=0) - parser.add_argument("-n", type=int, default=0) - parser.add_argument("-k", type=int, default=0) - parser.add_argument("-dtype", type=str, default='fp16', help="Input data type, default is fp16") - parser.add_argument("--specify_type", action='store_true', default=False, help="Whether user specify data type, default false") - parser.add_argument("--specify_size", action='store_true', default=False, help="Whether user specify input matrix size, default false") - parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") - parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') - parser.add_argument("--rocprof", action='store_true', default=False, help='Use rocprof to measure kernel time, default uses do_bench()!') - parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") - args = parser.parse_args() - - return args - -def main(): - args = parse_args() - dtype = torch.float16 - if args.specify_type: - if args.dtype == 'fp16': - dtype = torch.float16 - elif args.dtype == 'fp32': - dtype = torch.float32 - elif args.dtype == 'bf16': - dtype = torch.bfloat16 - else: - print(f"Unsupported datatype {args.dtype}") - sys.exit(1) - use_rocprof = args.rocprof - verbose = args.v - - mnks = [] - if args.specify_size: - M = args.m - N = args.n - K = args.k - if M == 0 or N == 0 or K == 0: - print(f"Input matrix size: (M {M}, N {N}, K {K}) contains dim size 0!") - mnks = [(M, N, K)] - else: - matrix_size_file = args.gemm_size_file - if matrix_size_file == "" or not os.path.isfile(matrix_size_file): - print(f"Matrix size file: {matrix_size_file} does not exist!") - sys.exit(1) - - with open(matrix_size_file) as file: - matrix_sizes = yaml.safe_load(file) - - for sizes in matrix_sizes: - M = sizes['M'] - N = sizes['N'] - K = sizes['K'] - mnks.append((M, N, K)) - - - for (m, n, k) in mnks: - min_ms = run_speed(m, n, k, dtype, use_rocprof, 'triton') - - # function to compute flops - perf_flops = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - - if args.compare: - test_correctness(m, n, k, dtype) - best_config = matmul_kernel_splitK.get_best_config() - - if use_rocprof: - dtype_str = 'fp16' if (not args.specify_type) else args.dtype - block_m = best_config.kwargs['BLOCK_SIZE_M'] - block_n = best_config.kwargs['BLOCK_SIZE_N'] - block_k = best_config.kwargs['BLOCK_SIZE_K'] - group_m = best_config.kwargs['GROUP_SIZE_M'] - split_k = best_config.kwargs['SPLIT_K'] - # num_warps = best_config['num_warps'] - num_warps = best_config.num_warps - driver = 'rocprof_gemm.py' - TRITON_DIR = os.getenv('TRITON_DIR') - if TRITON_DIR is not None: - driver = os.path.join(TRITON_DIR, 'python/perf-kernels/tune_gemm', driver) - run_cmd = f'python {driver} -m {m} -n {n} -k {k} \ - -block_m {block_m} -block_n {block_n} -block_k {block_k} \ - -group_m {group_m} -split_k {split_k} -num_warps {num_warps} \ - -dtype {dtype_str}' - prof_cmd = f'rocprof --stats {run_cmd}' - run_bash_command(prof_cmd) - - parse_result_cmd = f'sed -n \'/matmul_kernel/p\' results.stats.csv | awk -F \',\' \'{{print $4}}\'' - parse_outputs = run_bash_command(parse_result_cmd) - min_ms = int(parse_outputs[0]) / 1000000 - - out_str = f'SIZE: {m},{n},{k} ' - # print best config - if verbose: - out_str += f' best_config: ({best_config}), ' - out_str += f'TFLOPS: {perf_flops(min_ms)} time(ns): {min_ms * 1000000}' - print(out_str) - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/python/perf-kernels/tune_gemm/matmul_kernel.py b/python/perf-kernels/tune_gemm/matmul_kernel.py index d5f854f3d8a1..336a643dca50 100644 --- a/python/perf-kernels/tune_gemm/matmul_kernel.py +++ b/python/perf-kernels/tune_gemm/matmul_kernel.py @@ -3,17 +3,10 @@ @triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, bias_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_bias, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr, - EVEN_K: tl.constexpr -): +def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, + stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr, + EVEN_K: tl.constexpr): pid = tl.program_id(axis=0) pid_z = tl.program_id(1) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) diff --git a/python/perf-kernels/tune_gemm/one_config.py b/python/perf-kernels/tune_gemm/one_config.py index 52e6fba4a0a5..5354a270f493 100644 --- a/python/perf-kernels/tune_gemm/one_config.py +++ b/python/perf-kernels/tune_gemm/one_config.py @@ -7,6 +7,7 @@ import sys import tune_gemm + def parse_args(): parser = argparse.ArgumentParser( prog="check corectness of particular config for tuning gemm script", @@ -21,7 +22,8 @@ def parse_args(): parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type") parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type") parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type") - parser.add_argument("--init_type", type=str, default='randn', help="Initialization type for input matrices (default uniform rand [0, 1.0)])") + parser.add_argument("--init_type", type=str, default='randn', + help="Initialization type for input matrices (default uniform rand [0, 1.0)])") parser.add_argument("--bias_vector", action='store_true', default=False, help="apply bias vector") parser.add_argument("--block_m", type=int, default=0) parser.add_argument("--block_n", type=int, default=0) @@ -33,7 +35,10 @@ def parse_args(): parser.add_argument("--waves_per_eu", type=int, default=0) parser.add_argument("--matrix_instr_nonkdim", type=int, default=0) parser.add_argument("--kpack", type=int, default=0) - parser.add_argument("--config_str", type=str, default="", help="can take from tune_gemm.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP2_mfma16") + parser.add_argument( + "--config_str", type=str, default="", help= + "can take from tune_gemm.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP2_mfma16" + ) args = parser.parse_args() return args @@ -41,20 +46,23 @@ def parse_args(): def parse_config(cfg_str): values = cfg_str.split("_") - config_name = {"M": "M", - "N": "N", - "K": "K", - "BM": "BLOCK_SIZE_M", - "BN": "BLOCK_SIZE_N", - "BK": "BLOCK_SIZE_K", - "GM": "GROUP_SIZE_M", - "SK": "SPLIT_K", - "nW": "num_warps", - "nS": "num_stages", - "EU": "waves_per_eu", - "kP": "kpack", - "mfma": "matrix_instr_nonkdim" - } + # yapf: disable + config_name = { + "M": "M", + "N": "N", + "K": "K", + "BM": "BLOCK_SIZE_M", + "BN": "BLOCK_SIZE_N", + "BK": "BLOCK_SIZE_K", + "GM": "GROUP_SIZE_M", + "SK": "SPLIT_K", + "nW": "num_warps", + "nS": "num_stages", + "EU": "waves_per_eu", + "kP": "kpack", + "mfma": "matrix_instr_nonkdim", + } + # yapf: enable config = {} for val in values: match = re.search("([a-zA-Z]*)([0-9]*)", val) @@ -69,21 +77,25 @@ def main(): if args.config_str: config = parse_config(args.config_str) else: - config = {"M": args.m, - "N": args.n, - "K": args.k, - "BLOCK_SIZE_M": args.block_m, - "BLOCK_SIZE_N": args.block_n, - "BLOCK_SIZE_K": args.block_k, - "GROUP_SIZE_M": args.group_m, - "SPLIT_K": args.split_k, - "num_warps": args.num_warps, - "num_stages": args.num_stages, - "waves_per_eu": args.waves_per_eu, - "kpack": args.kpack, - "matrix_instr_nonkdim": args.matrix_instr_nonkdim - } - tune_gemm.test_correctness(config["M"], config["N"], config["K"], args.col_a, args.col_b, args.dtype_a, args.dtype_b, args.dtype_c, args.init_type, config, args.bias_vector, verbose=True) + # yapf: disable + config = { + "M": args.m, + "N": args.n, + "K": args.k, + "BLOCK_SIZE_M": args.block_m, + "BLOCK_SIZE_N": args.block_n, + "BLOCK_SIZE_K": args.block_k, + "GROUP_SIZE_M": args.group_m, + "SPLIT_K": args.split_k, + "num_warps": args.num_warps, + "num_stages": args.num_stages, + "waves_per_eu": args.waves_per_eu, + "kpack": args.kpack, + "matrix_instr_nonkdim": args.matrix_instr_nonkdim, + } + # yapf: enable + tune_gemm.test_correctness(config["M"], config["N"], config["K"], args.col_a, args.col_b, args.dtype_a, + args.dtype_b, args.dtype_c, args.init_type, config, args.bias_vector, verbose=True) if __name__ == "__main__": diff --git a/python/perf-kernels/tune_gemm/rocprof_gemm.py b/python/perf-kernels/tune_gemm/rocprof_gemm.py deleted file mode 100755 index 8103fad554f7..000000000000 --- a/python/perf-kernels/tune_gemm/rocprof_gemm.py +++ /dev/null @@ -1,318 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import sys - -import torch -import triton -import triton.language as tl - - -@triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_SIZE_K'] * args['SPLIT_K']) == 0, -}) -@triton.jit -def matmul_kernel_splitK( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ACTIVATION: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - # See above `L2 Cache Optimizations` section for details. - pid = tl.program_id(axis=0) - pid_z = tl.program_id(1) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetics` section for details - offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - if torch.version.hip is None: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - else: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - if EVEN_K: - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - else: - k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) - a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - if ACTIVATION == "leaky_relu": - accumulator = leaky_relu(accumulator) - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if SPLIT_K == 1: - tl.store(c_ptrs, c, mask=c_mask) - else: - tl.atomic_add(c_ptrs, c, mask=c_mask) - - -# Kernel no split K -@triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, -}) -@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, EVEN_K: tl.constexpr, - ACTIVATION: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - # See above `L2 Cache Optimizations` section for details. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetics` section for details - offs_k = tl.arange(0, BLOCK_SIZE_K) - if torch.version.hip is None: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - else: - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - if EVEN_K: - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - else: - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - if ACTIVATION == "leaky_relu": - accumulator = leaky_relu(accumulator) - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. -@triton.jit -def leaky_relu(x): - x = x + 1 - return tl.where(x >= 0, x, 0.01 * x) - - -def need_split_k(SIZE_M, SIZE_N, SIZE_K): - return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 - - -def matmul(a, b, block_m, block_n, block_k, group_m, split_k, num_warps, activation=""): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.is_contiguous(), "Matrix A must be contiguous" - assert b.is_contiguous(), "Matrix B must be contiguous" - M, K = a.shape - K, N = b.shape - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - # 1D launch kernel where each block gets its own program. - - if need_split_k(M, N, K): - grid_splitK = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'] - ) - matmul_kernel_splitK[grid_splitK]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_SIZE_M = block_m, - BLOCK_SIZE_N = block_n, - BLOCK_SIZE_K = block_k, - GROUP_SIZE_M = group_m, - SPLIT_K = split_k, - num_warps = num_warps, - num_stages = 1, - ACTIVATION=activation - ) - - else: - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_SIZE_M = block_m, - BLOCK_SIZE_N = block_n, - BLOCK_SIZE_K = block_k, - GROUP_SIZE_M = group_m, - num_warps = num_warps, - num_stages = 1, - ACTIVATION=activation - ) - - return c - - -def test_gemm(M, N, K, block_m, block_n, block_k, group_m, split_k, num_warps, dtype): - a = torch.randn((M, K), device='cuda', dtype=dtype) - b = torch.randn((K, N), device='cuda', dtype=dtype) - c = matmul(a, b, block_m, block_n, block_k, group_m, split_k, num_warps) - - return c - - -def main(args=None): - if args is None: - args = sys.argv[1:] - - parser = argparse.ArgumentParser( - prog="test gemm tuning", - description="Tuning infra for triton gemm", - allow_abbrev=False, - ) - - parser.add_argument("-m", type=int, default=argparse.SUPPRESS) - parser.add_argument("-n", type=int, default=argparse.SUPPRESS) - parser.add_argument("-k", type=int, default=argparse.SUPPRESS) - parser.add_argument("-block_m", type=int, default=argparse.SUPPRESS) - parser.add_argument("-block_n", type=int, default=argparse.SUPPRESS) - parser.add_argument("-block_k", type=int, default=argparse.SUPPRESS) - parser.add_argument("-group_m", type=int, default=argparse.SUPPRESS) - parser.add_argument("-split_k", type=int, default=argparse.SUPPRESS) - parser.add_argument("-num_warps", type=int, default=argparse.SUPPRESS) - parser.add_argument("-dtype", type=str, default='fp16', help="Input/output data type") - parsed_args = parser.parse_args(args) - - dtype = torch.float16 - if parsed_args.dtype == 'fp16': - dtype = torch.float16 - elif parsed_args.dtype == 'fp32': - dtype = torch.float32 - elif parsed_args.dtype == 'bf16': - dtype = torch.bfloat16 - else: - print(f"Unsupported datatype {args.dtype}") - sys.exit(1) - - M = parsed_args.m - N = parsed_args.n - K = parsed_args.k - block_m = parsed_args.block_m - block_n = parsed_args.block_n - block_k = parsed_args.block_k - group_m = parsed_args.group_m - split_k = parsed_args.split_k - num_warps = parsed_args.num_warps - test_gemm(M, N, K, block_m, block_n, block_k, group_m, split_k, num_warps, dtype) - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/python/perf-kernels/tune_gemm/tune_gemm.py b/python/perf-kernels/tune_gemm/tune_gemm.py index 3fdd7da082b5..d49823306e3a 100755 --- a/python/perf-kernels/tune_gemm/tune_gemm.py +++ b/python/perf-kernels/tune_gemm/tune_gemm.py @@ -16,8 +16,26 @@ import multiprocessing import pandas as pd -from utils.file_generator import * -from utils.utils import * +from utils.file_generator import ( + gen_configStr, + generate_compile_driver, + generate_matmul_kernels, + generate_profile_tasks, + read_config, +) +from utils.utils import ( + get_default_tuning_result_filename, + get_filename_compile_driver, + get_filename_myKernels, + get_filename_profile_driver, + name_to_tl_types, + patch_triton_compiler, + run_bash_command, + run_bash_command_wrapper, + tl_to_torch_types, + TORCH_HAS_FP8E4B8, + TORCH_HAS_FP8E5B16, +) def is_hip_available(): @@ -56,26 +74,10 @@ def get_full_tuning_space(): for matrix_instr_nonkdim in matrix_instr_nonkdim_range: for kpack in kpack_range: configs.append({ - 'BLOCK_SIZE_M': - block_m, - 'BLOCK_SIZE_N': - block_n, - 'BLOCK_SIZE_K': - block_k, - 'GROUP_SIZE_M': - group_m, - 'SPLIT_K': - split_k, - 'num_warps': - num_warps, - 'num_stages': - num_stages, - 'waves_per_eu': - waves_per_eu, - 'matrix_instr_nonkdim': - matrix_instr_nonkdim, - 'kpack': - kpack + 'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': + block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps': + num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu, + 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack }) return configs @@ -106,7 +108,6 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): num_warps = config.get("num_warps") num_stages = config.get("num_stages") matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") - kpack = config.get("kpack") if matrix_instr_nonkdim > mfma: continue if mfma == 4 and BLOCK_SIZE_K < 64: @@ -173,9 +174,8 @@ def extract_kernel_time(M, N, K, config, df, bias_size): # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should # not need below two lines cols = [ - 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', - 'tid', 'grd', 'wgr', 'lds', 'scr', 'arch_vgpr', 'accum_vgpr', 'sgpr', - 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' + 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr', + 'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' ] df.columns = cols configStr = gen_configStr(config) @@ -202,51 +202,30 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose): jobId += ngpus -def tune_gemm_config(M, - N, - K, - col_a, - col_b, - dtype_a, - dtype_b, - dtype_c, - init_type, - configs, - run_bench, - jobs, - iters, - skipWarmup, - verbose=0, - num_threads=32, - gpus=[0], - rotating_buffer_size=256, - bias_size=0, +def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, run_bench, jobs, iters, + skipWarmup, verbose=0, num_threads=32, gpus=[0], rotating_buffer_size=256, bias_size=0, icache_flush=False): # precompile the kernels in parallel start_time = datetime.now() if not skipWarmup: # Generate kernel out of all configs - fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, - dtype_b, dtype_c, init_type, configs, + fname = generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, rotating_buffer_size, bias_size) - run_bash_command(f"python {fname} -n {num_threads}", - capture=(verbose < 2)) + run_bash_command(f"python {fname} -n {num_threads}", capture=(verbose < 2)) compile_end = datetime.now() compile_time = compile_end - start_time if verbose: print(f"compile time: {compile_time}", flush=True) # Generate kernels out of all configs - generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, - init_type, configs, jobs, iters, run_bench, + generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush) # profile generated kernels running = [ - multiprocessing.Process(target=profile_batch_kernels, - args=(M, N, K, gpu_id, gpus, jobs, verbose)) + multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, gpu_id, gpus, jobs, verbose)) for gpu_id in gpus ] for p in running: @@ -266,20 +245,12 @@ def tune_gemm_config(M, tasks = [] idx = 0 df_prof = [ - pd.read_csv(f"results_{i}.csv", - skiprows=1, - header=None, - delimiter=',', - quotechar='"', - escapechar='\\') for i in range(jobs) + pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\') + for i in range(jobs) ] for config in configs: file_idx = idx % jobs - tasks += [ - thread_pool.apply_async(extract_kernel_time, - args=(M, N, K, config, df_prof[file_idx], - bias_size)) - ] + tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))] idx += 1 thread_pool.close() thread_pool.join() @@ -293,9 +264,7 @@ def tune_gemm_config(M, bestConfig = config else: min_us = -1 - print( - f"invalid config(post processing): SIZE {M} {N} {K}: {config}", - flush=True) + print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True) post_end = datetime.now() post_time = post_end - profile_end if verbose: @@ -309,8 +278,7 @@ def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'): torch.cuda.manual_seed(seed) @triton.jit - def copy_kernel(input_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr): + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements input = tl.load(input_ptr + offsets, mask=mask) @@ -319,13 +287,11 @@ def copy_kernel(input_ptr, output_ptr, n_elements, def init_by_size_and_type(size, dtype, init_type): if init_type == 'hpl': - return torch.empty(size, device='cuda', - dtype=dtype).uniform_(-0.5, 0.5) + return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5) # This init type has element[i] in row[j] equal to sin(i+j*N) elif init_type == 'trig_float': M, N = size - return torch.reshape(torch.arange(0, M * N), - (M, N)).sin().to(dtype=dtype, device='cuda') + return torch.reshape(torch.arange(0, M * N), (M, N)).sin().to(dtype=dtype, device='cuda') elif init_type == 'zeros': return torch.zeros(size, dtype=dtype, device='cuda') elif init_type == "randn": @@ -334,8 +300,7 @@ def init_by_size_and_type(size, dtype, init_type): else: raise ValueError("Bad matrix initialization type.") - raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), - torch.float32, init_type) + raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), torch.float32, init_type) if needTrans: raw_data = raw_data.T if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ @@ -356,19 +321,8 @@ def init_by_size_and_type(size, dtype, init_type): # generate inputs/outputs according to rotating tensor size -def gen_rotating_tensors(M, - N, - K, - dtype_a, - need_Trans_a, - dtype_b, - need_Trans_b, - dtype_c, - seed, - init_type, - rotating_buffer_size, - bias_size, - device='cuda'): +def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b, dtype_c, seed, init_type, + rotating_buffer_size, bias_size, device='cuda'): a_size = M * K * type_name_to_bytes(dtype_a) b_size = K * N * type_name_to_bytes(dtype_b) c_size = M * N * type_name_to_bytes(dtype_c) @@ -384,50 +338,23 @@ def gen_rotating_tensors(M, c = [] bias = [] for i in range(block_count): - in_a, in_a_fp16 = gen_input(M, - K, - dtype_a, - need_Trans_a, - 1, - init_type, - device='cuda') + in_a, in_a_fp16 = gen_input(M, K, dtype_a, need_Trans_a, 1, init_type, device='cuda') a.append(in_a) - in_b, in_b_fp16 = gen_input(K, - N, - dtype_b, - need_Trans_b, - 2, - init_type, - device='cuda') + in_b, in_b_fp16 = gen_input(K, N, dtype_b, need_Trans_b, 2, init_type, device='cuda') b.append(in_b) - out_c = torch.zeros((M, N), - dtype=tl_to_torch_types[name_to_tl_types[dtype_c]], - device='cuda') + out_c = torch.zeros((M, N), dtype=tl_to_torch_types[name_to_tl_types[dtype_c]], device='cuda') c.append(out_c) if bias_size > 0: - bs, bs_fp16 = gen_input(M, - 1, - dtype_b, - need_Trans_b, - 2, - init_type, - device='cuda') - bias.append(bs.squeeze()) - - in_outs = { - "rotating_num": block_count, - "input_a": a, - "input_b": b, - "output_c": c, - "bias": bias - } + bs, bs_fp16 = gen_input(M, 1, dtype_b, need_Trans_b, 2, init_type, device='cuda') + bias.append(bs.squeeze(dim=1)) + + in_outs = {"rotating_num": block_count, "input_a": a, "input_b": b, "output_c": c, "bias": bias} return in_outs -def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, - num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, - use_bias): +def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, + mfmaInstrSize, kpack, use_bias): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" #assert a.is_contiguous(), "Matrix A must be contiguous" @@ -439,37 +366,15 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k stride_bias = bias.stride(0) if use_bias else 0 EVEN_K = K % block_k == 0 - matmul_kernel[grid](a, - b, - c, - bias, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - stride_bias=stride_bias, - BLOCK_SIZE_M=block_m, - BLOCK_SIZE_N=block_n, - BLOCK_SIZE_K=block_k, - GROUP_SIZE_M=group_m, - SPLIT_K=split_k, - num_warps=num_warps, - num_stages=num_stages, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, - BIAS=use_bias, - EVEN_K=EVEN_K) + matmul_kernel[grid](a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), + c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, + BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps, + num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, + kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K) return c -def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, - init_type, config, bias_vector, verbose): +def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose): block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( config) use_bias = bias_vector @@ -480,22 +385,13 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda') bias = None if use_bias: - bias, bias_fp16 = gen_input(M, - 1, - dtype_b, - col_b, - 2, - init_type, - device='cuda') - bias = bias.squeeze() - bias_fp16 = bias.squeeze() + bias, bias_fp16 = gen_input(M, 1, dtype_b, col_b, 2, init_type, device='cuda') + bias = bias.squeeze(dim=1) + bias_fp16 = bias.squeeze(dim=1) # Allocates output. - c = torch.zeros((M, N), - device=a.device, - dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) - triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, - split_k, num_warps, num_stages, waves_per_eu, - mfmaInstrSize, kpack, use_bias) + c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) + triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, + waves_per_eu, mfmaInstrSize, kpack, use_bias) torch_output = torch.matmul(a_fp16, b_fp16) if use_bias: torch_output += bias_fp16[:, None] @@ -506,10 +402,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, size_str = '' if verbose: size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}' - if torch.allclose(triton_output.to(torch.float16), - torch_output, - atol=atol, - rtol=rtol): + if torch.allclose(triton_output.to(torch.float16), torch_output, atol=atol, rtol=rtol): print(f'{size_str} Correct✅') else: print(f"triton_output={triton_output}") @@ -526,111 +419,41 @@ def parse_args(): parser.add_argument("-m", type=int, default=0) parser.add_argument("-n", type=int, default=0) parser.add_argument("-k", type=int, default=0) - parser.add_argument("-col_a", - action='store_true', - default=False, - help='whether matrix a is column major') - parser.add_argument("-col_b", - action='store_true', - default=False, - help='whether matrix b is column major') - parser.add_argument("-dtype_a", - type=str, - default='fp16', - help="matrix a element data type") - parser.add_argument("-dtype_b", - type=str, - default='fp16', - help="matrix b element data type") - parser.add_argument("-dtype_c", - type=str, - default='fp16', - help="output element data type") - parser.add_argument("--ngpus", - type=int, - default=0, - help='number of GPUs used in the profiling step') - parser.add_argument("--gpu_ids", - type=lambda s: [int(id) for id in s.split(',')], - default=[], + parser.add_argument("-col_a", action='store_true', default=False, help='whether matrix a is column major') + parser.add_argument("-col_b", action='store_true', default=False, help='whether matrix b is column major') + parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type") + parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type") + parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type") + parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step') + parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], help='list of gpu ids to use for tuning') - parser.add_argument("--gemm_size_file", - type=str, - default="", - help='yaml file to indicate matrix size') - parser.add_argument("--o", - type=str, - default='', - help='yaml file to store tuning results') - parser.add_argument("--keep", - action='store_true', - default=False, - help='keep generated files') - parser.add_argument("--compare", - action='store_true', - default=False, - help="Whether check result correctness") - parser.add_argument( - "--compare_wo_tuning", - action='store_true', - default=False, - help="Whether check result correctness without tuning.") - parser.add_argument("--benchmark", - action='store_true', - default=False, - help="Benchmark the given config") - parser.add_argument( - "--time_breakdown", - action='store_true', - default=False, - help="Show detailed time breakdown of each step during the tuning") - parser.add_argument( - "--verbose", - action='store_true', - default=False, - help="enables time_breakdown and additional logging messages") - parser.add_argument( - "--num_threads", - type=int, - default=32, - help= - "number of threads to use for kernel compilation and post processing") - parser.add_argument("--jobs", - type=int, - default=1, - help="number of tasks during the profiling process") - parser.add_argument("--iters", - type=int, - default=1000, - help="number of iterations used in --benchmark mode") - parser.add_argument( - "--init_type", - type=str, - default='randn', - choices=['randn', 'hpl', 'trig_float', 'zeros'], - help="Input tensor initialization (default normal distribution)") + parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') + parser.add_argument("--o", type=str, default='', help='yaml file to store tuning results') + parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') + parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") + parser.add_argument("--compare_wo_tuning", action='store_true', default=False, + help="Whether check result correctness without tuning.") + parser.add_argument("--benchmark", action='store_true', default=False, help="Benchmark the given config") + parser.add_argument("--time_breakdown", action='store_true', default=False, + help="Show detailed time breakdown of each step during the tuning") + parser.add_argument("--verbose", action='store_true', default=False, + help="enables time_breakdown and additional logging messages") + parser.add_argument("--num_threads", type=int, default=32, + help="number of threads to use for kernel compilation and post processing") + parser.add_argument("--jobs", type=int, default=1, help="number of tasks during the profiling process") + parser.add_argument("--iters", type=int, default=1000, help="number of iterations used in --benchmark mode") + parser.add_argument("--init_type", type=str, default='randn', choices=['randn', 'hpl', 'trig_float', 'zeros'], + help="Input tensor initialization (default normal distribution)") parser.add_argument( - "--rotating_tensor", - type=int, - default=0, - help="total size (MB) of all tensors (a, b, c, bias)." + "--rotating_tensor", type=int, default=0, help="total size (MB) of all tensors (a, b, c, bias)." " The default value is 0 (no rotating tensor)." " When set, it needs to be larger than the L1, L2, MALL size)") - parser.add_argument("--bias_vector", - action='store_true', - default=False, - help="apply bias vector") - parser.add_argument("--icache_flush", - action='store_true', - default=False, + parser.add_argument("--bias_vector", action='store_true', default=False, help="apply bias vector") + parser.add_argument("--icache_flush", action='store_true', default=False, help="apply icache flush in tuning performance") - parser.add_argument("--no_warmup", - action='store_true', - default=False, + parser.add_argument("--no_warmup", action='store_true', default=False, help="Whether we want to skip the compilation stage") - parser.add_argument("--hack_triton_compiler", - action='store_true', - default=False, + parser.add_argument("--hack_triton_compiler", action='store_true', default=False, help="Modify the triton source to avoid backend query") args = parser.parse_args() if not args.o: @@ -719,10 +542,8 @@ def main(): dtype_a = args.dtype_a dtype_b = args.dtype_b dtype_c = args.dtype_c - if not dtype_a in name_to_tl_types or not dtype_b in name_to_tl_types or not dtype_c in name_to_tl_types: - print( - f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}" - ) + if dtype_a not in name_to_tl_types or dtype_b not in name_to_tl_types or dtype_c not in name_to_tl_types: + print(f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}") print("Supported types: ", list(name_to_tl_types.keys())) sys.exit(1) rotating_buffer_size = args.rotating_tensor @@ -759,10 +580,8 @@ def main(): if args.compare_wo_tuning: for (M, N, K, col_a, col_b, myConfig) in mnks: if myConfig is None: - raise Exception( - "kernel config is None, need to provide a tuning config") - test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, - init_type, myConfig, bias_vector, True) + raise Exception("kernel config is None, need to provide a tuning config") + test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, bias_vector, True) return configs_full = get_full_tuning_space() @@ -775,8 +594,7 @@ def main(): print("trans M N K TFLOPS us") f_results.write("trans,M,N,K,TFLOPS,us\n") else: - print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", - flush=True) + print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True) f_results.close() @@ -800,14 +618,11 @@ def main(): start_local_time = datetime.now() # Obtain a pruned tuning space according to gemm size # If running benchmark, use the provided config - pruned_configs = [myConfig] if run_bench else prune_configs( - M, N, K, configs_full, type_name_to_bytes(dtype_a), - type_name_to_bytes(dtype_b)) + pruned_configs = [myConfig] if run_bench else prune_configs(M, N, K, configs_full, type_name_to_bytes(dtype_a), + type_name_to_bytes(dtype_b)) ## Only append new configs from the current gemm size - delta_configs = [ - config for config in pruned_configs if config not in configs - ] + delta_configs = [config for config in pruned_configs if config not in configs] configs += delta_configs ## Append new configs into the tuning space @@ -817,12 +632,9 @@ def main(): row_b_str = 'N' if col_b else 'T' size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}' if not run_bench: - print(f"{size_str} nConfigs: {len(pruned_configs)}", - end=" ", - flush=True) + print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True) else: - print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", - end="") + print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", end="") f_results.write(f"{row_a_str}{row_b_str},{M},{N},{K},") # The main tuning funtion for one gemm size @@ -834,26 +646,9 @@ def main(): # we consider bias size as M for now. bias_size = M if bias_vector else 0 minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config( - M, - N, - K, - col_a, - col_b, - dtype_a, - dtype_b, - dtype_c, - init_type, - pruned_configs, - run_bench, - jobs, - iters, - skipWarmup, - num_threads=args.num_threads, - gpus=gpus, - verbose=verbose_level, - rotating_buffer_size=rotating_buffer_size, - bias_size=bias_size, - icache_flush=icache_flush) + M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, pruned_configs, run_bench, jobs, iters, + skipWarmup, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level, + rotating_buffer_size=rotating_buffer_size, bias_size=bias_size, icache_flush=icache_flush) # post processing the numbers perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6) @@ -861,41 +656,29 @@ def main(): formatted_tflops = format_output(tri_tflops) minTime = format_output(minTime) if not run_bench: - print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', - end=" ", - flush=True) + print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True) bestConfig_compact_str = gen_configStr(bestConfig) if not run_bench: - print(f'best_config: {bestConfig_compact_str}', - end=" ", - flush=True) + print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True) # write best config to tuning_results.yaml if run_bench: print(f"{formatted_tflops} {minTime}") f_results.write(f"{formatted_tflops},{minTime}\n") - sizeDict = { - 'M': M, - 'N': N, - 'K': K, - 'rowMajorA': row_a_str, - 'rowMajorB': row_b_str - } + sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str} sizeDict.update(bestConfig) if not run_bench: f_results.write("- " + str(sizeDict) + " ") - f_results.write( - f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') + f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') # remove generated files if asked to if not keepTmp: if not skipWarmup: os.remove(get_filename_compile_driver()) try: - os.remove(get_filename_compile_driver() + - ".failed_configs") + os.remove(get_filename_compile_driver() + ".failed_configs") except OSError: pass for i in range(jobs): @@ -907,8 +690,8 @@ def main(): # Check correctness if asked to if args.compare: print("correctness: ", end=" ", flush=True) - test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, - init_type, bestConfig, bias_vector, False) + test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, bias_vector, + False) elif not run_bench: print("", flush=True) @@ -928,9 +711,7 @@ def main(): print(f"Total tuning time (h:m:s): {tuning_time}") if hack_triton: - print( - "Triton compiler is hacked, don't forget to git restore the changes :)" - ) + print("Triton compiler is hacked, don't forget to git restore the changes :)") if __name__ == '__main__': diff --git a/python/perf-kernels/tune_gemm/tune_gemm.sh b/python/perf-kernels/tune_gemm/tune_gemm.sh deleted file mode 100755 index b49b57aa0aa3..000000000000 --- a/python/perf-kernels/tune_gemm/tune_gemm.sh +++ /dev/null @@ -1,27 +0,0 @@ -#! /bin/bash - -## $1: driver program -## $2: M -## $3: N -## $4: K -## $5: 1: reduced tuning space - -if [[ $# -lt 4 ]];then - echo "Usage: ./tune_gemm.sh M N K" - exit -fi - -DRIVER=$1 -M=$2 -N=$3 -K=$4 -reduceSpace=$5 - -DRIVER=$(echo $DRIVER | sed -e "s/matmul_grouped.py/matmul.py/g") - -# $DRIVER is the actual tuning scripts, it is the file matmul.py -# -mnk are the size of input matrices, matrix (m, k) x (k, n) -# --specify_size means using -mnk to specify size of input matrices -# --rocprof means using rocprof to measure kernel time. If not set, -# kernel time is from do_bench() -python $DRIVER -m $M -n $N -k $K --specify_size --rocprof diff --git a/python/perf-kernels/tune_gemm/utils/file_generator.py b/python/perf-kernels/tune_gemm/utils/file_generator.py index eea92cf6bf48..1011bc9df805 100644 --- a/python/perf-kernels/tune_gemm/utils/file_generator.py +++ b/python/perf-kernels/tune_gemm/utils/file_generator.py @@ -1,5 +1,13 @@ import os -from .utils import * + +from .utils import ( + get_filename_compile_driver, + get_filename_myKernels, + get_filename_profile_driver, + get_filename_without_extension, + name_to_tl_types, + tl_to_torch_types, +) def read_config(config): @@ -44,20 +52,15 @@ def generate_matmul_kernels(configs): import triton.language as tl""" f_kernel.write(import_str) - with open( - os.path.dirname(os.path.abspath(__file__)) + - "/../matmul_kernel.py") as file: + with open(os.path.dirname(os.path.abspath(__file__)) + "/../matmul_kernel.py") as file: matmul_kernel_code = file.read() for config in configs: configStr = gen_configStr(config) # Copy the matmul_kernel with name replaced - matmul_kernel_config = matmul_kernel_code.replace( - "matmul_kernel", f"matmul_kernel_{configStr}") - matmul_kernel_config = matmul_kernel_config.replace( - "import triton.language as tl", "") - matmul_kernel_config = matmul_kernel_config.replace( - "import triton", "") + matmul_kernel_config = matmul_kernel_code.replace("matmul_kernel", f"matmul_kernel_{configStr}") + matmul_kernel_config = matmul_kernel_config.replace("import triton.language as tl", "") + matmul_kernel_config = matmul_kernel_config.replace("import triton", "") f_kernel.write(matmul_kernel_config) f_kernel.close() @@ -65,8 +68,7 @@ def generate_matmul_kernels(configs): ## construct the configStr and generate the wrapper function matmul_{configStr}() ## If `warmup` is set, the generated kernel will be **compiled** -def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, - dtype_c, bias_size, warmup): +def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, warmup): block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config( config) @@ -141,8 +143,7 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn): return configStr, matmul_def_str -def generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, - init_type, configs, rotating_buffer_size, +def generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, rotating_buffer_size, bias_size): """ Generate a single file that contains all kernels in the tuning space. @@ -167,8 +168,8 @@ def generate_compile_driver(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, for config in configs: EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False - configStr, matmul_def_str = gen_kernel_and_configStr_from_config( - config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, True) + configStr, matmul_def_str = gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, + bias_size, True) # Copy the matmul_kernel with name replaced f_kernel.write(matmul_def_str + "\n") @@ -242,8 +243,7 @@ def main(): return filename -def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, - init_type, configs, jobs, iters, run_bench, +def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, configs, jobs, iters, run_bench, rotating_buffer_size, bias_size, icache_flush): """ Open {len(jobs)} files @@ -280,8 +280,8 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, for config in configs: file_idx = idx % jobs EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False - configStr, matmul_def_str = gen_kernel_and_configStr_from_config( - config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, False) + configStr, matmul_def_str = gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, + bias_size, False) # Copy the matmul_kernel with name replaced f_kernel[file_idx].write(matmul_def_str + "\n") idx += 1 @@ -311,7 +311,6 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, # call all matmul_xxx functions idx = 0 runs = iters if run_bench else 200 - call_icache_flush = 'icache_flush()' if icache_flush else '' for config in configs: configStr = gen_configStr(config) matmul_call_str = f""" @@ -324,7 +323,7 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, bias = tensors['bias'][i % rotating_num] if bias_size > 0 else None bias_stride = bias.stride(0) if bias_size > 0 else 0""" if icache_flush: - matmul_call_str += f""" + matmul_call_str += """ icache_flush()""" matmul_call_str += f""" d = matmul_{configStr}(a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), bias_stride)""" diff --git a/python/perf-kernels/tune_gemm/utils/utils.py b/python/perf-kernels/tune_gemm/utils/utils.py index 9b6b50ea626b..bcebf9a3ff8d 100644 --- a/python/perf-kernels/tune_gemm/utils/utils.py +++ b/python/perf-kernels/tune_gemm/utils/utils.py @@ -34,7 +34,7 @@ def run_bash_command_wrapper(commandstring, capture=True): try: run_bash_command(commandstring, capture) - except subprocess.CalledProcessError as e: + except subprocess.CalledProcessError: if not capture: print(f"running {commandstring} one more time") run_bash_command(commandstring, capture) @@ -42,16 +42,9 @@ def run_bash_command_wrapper(commandstring, capture=True): def run_bash_command(commandstring, capture=True): if capture: - proc = subprocess.run(commandstring, - shell=True, - check=True, - executable='/bin/bash', - stdout=subprocess.PIPE) + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) return proc.stdout.splitlines() - proc = subprocess.run(commandstring, - shell=True, - check=True, - executable='/bin/bash') + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash') return None @@ -111,5 +104,7 @@ def patch_triton_compiler(): cuda_driver_filename = os.path.join(triton_dir, "../third_party/nvidia/backend/", "driver.py") run_bash_command(f"sed -i 's/import torch/return True/g' {hip_driver_filename}") - run_bash_command(f"sed -i 's/device = self.get_current_device()/return GPUTarget(\"hip\", \"{target.arch}\", 64)/g' {hip_driver_filename}") + run_bash_command( + f"sed -i 's/device = self.get_current_device()/return GPUTarget(\"hip\", \"{target.arch}\", 64)/g' {hip_driver_filename}" + ) run_bash_command(f"sed -i 's/import torch/return False/g' {cuda_driver_filename}") From 15cb3a89720c7f4517b7419a917bd20f4ced3f30 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Mon, 19 Aug 2024 14:50:31 -0500 Subject: [PATCH 12/20] [tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 (#630) * Change to rocprofv1 * improve post processing of rocprof results - set --iters=200 as default. This is enough since the time is stable after the first few runs. - Filter out kernel time that is too large. We use the first kernel time as the threshold. There must be something wrong with the kernel if its elapsedTime is larger than the first run. We need to investigate the reason. For now, just filter them out. * Add xcd-based pid remapping * Enable EVEN_K=false for large gemms * Update readme --- python/perf-kernels/tune_gemm/README.md | 20 +++++++++- .../perf-kernels/tune_gemm/matmul_kernel.py | 16 +++++++- python/perf-kernels/tune_gemm/tune_gemm.py | 39 ++++++++----------- .../tune_gemm/utils/file_generator.py | 17 ++++++-- 4 files changed, 61 insertions(+), 31 deletions(-) diff --git a/python/perf-kernels/tune_gemm/README.md b/python/perf-kernels/tune_gemm/README.md index da45dcda5c3c..c22382143544 100644 --- a/python/perf-kernels/tune_gemm/README.md +++ b/python/perf-kernels/tune_gemm/README.md @@ -1,8 +1,9 @@ -# GEMM tuning script (current v3.3) +# GEMM tuning script (current v3.4) ## matmul kernel The matmul kernel implementation can be found as [matmul_kernel.py](https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tune_gemm/matmul_kernel.py), which includes the following features: +- XCD-based pid remapping - grouping order of workgroup id, which is controlled by `GROUP_SIZE_M`, that implements L2 cache optimization introduced in the [tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations). - split-k algorithm, which is controlled by `SPLIT_K`. @@ -144,7 +145,7 @@ The default value is 1000. The general idea of the tuning script can be summarized as - Compile all the kernels in the tuning space in parallel. -- Divide the tuning space into tasks and invoke `rocprofv2` once per +- Divide the tuning space into tasks and invoke `rocprof` once per task. This will save invocation overhead of the profiler. - Profile tasks in parallel on multiple GPUs. @@ -309,6 +310,21 @@ places: - Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589) +# GEMM Tuning Script v3.4 + +## API changes + +No API changes + +## Implementation changes + +- Now the matmul_kernel supports XCD-based pid remapping. Details with experiments +will be added later. +- Switched back to rocprofv1. Check [ticket#228](https://github.com/ROCm/triton-internal/issues/228) for more details. +- Improved the post-procesing logic to filter out the "spikes" in the profiling results. +- Reduced the number of iterations in both tuning and benchmark mode (120 and 200). + + # One config running script `one_config.py` is a script that runs one given matmul config. diff --git a/python/perf-kernels/tune_gemm/matmul_kernel.py b/python/perf-kernels/tune_gemm/matmul_kernel.py index 336a643dca50..1d9902bc2de6 100644 --- a/python/perf-kernels/tune_gemm/matmul_kernel.py +++ b/python/perf-kernels/tune_gemm/matmul_kernel.py @@ -6,11 +6,22 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr, - EVEN_K: tl.constexpr): + EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr): pid = tl.program_id(axis=0) pid_z = tl.program_id(1) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if NUM_XCDS != 1: + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = GRID_MN // NUM_XCDS + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + pid = xcd * pids_per_xcd + local_pid + if GROUP_SIZE_M == 1: pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -19,8 +30,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m + if SPLIT_K == 1: offs_k = tl.arange(0, BLOCK_SIZE_K) else: diff --git a/python/perf-kernels/tune_gemm/tune_gemm.py b/python/perf-kernels/tune_gemm/tune_gemm.py index d49823306e3a..291096b3d7af 100755 --- a/python/perf-kernels/tune_gemm/tune_gemm.py +++ b/python/perf-kernels/tune_gemm/tune_gemm.py @@ -54,7 +54,7 @@ def get_full_tuning_space(): block_k_range = [16, 32, 64, 128, 256] split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8, 16, 32] + group_m_range = [1, 2, 4, 8, 16, 32] # For now we see better perf with num_stages=0 for all gemm configs we care # But keep this explicit so that we do not forget we may need to set it to # other values in the future @@ -157,7 +157,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): if num_warps < 4: continue # check if tiling is integer multiple of GEMM size because we have no boundary check - if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0: + if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0: continue pruned_configs.append(config) @@ -169,20 +169,15 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K): return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 -def extract_kernel_time(M, N, K, config, df, bias_size): - # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19 - # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should - # not need below two lines - cols = [ - 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr', - 'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' - ] - df.columns = cols +def extract_kernel_time(M, N, K, config, df): configStr = gen_configStr(config) - filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy() - filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs'] - meanTime = filtered_df['DurationNs'].tail(100).mean() - return config, meanTime + df = df[df['KernelName'].str.contains(configStr)] + + first_value = df['DurationNs'].iloc[0] + filtered_data = df['DurationNs'][df['DurationNs'] <= first_value] + new_meanTime = filtered_data.tail(100).mean() + + return config, new_meanTime def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose): @@ -197,7 +192,7 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose): if verbose: print(f"profiling {kernel_name} on GPU {gpuid}") run_bash_command_wrapper( - f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}", + f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}", capture=(verbose < 2)) jobId += ngpus @@ -244,13 +239,10 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type thread_pool = multiprocessing.Pool(processes=num_threads) tasks = [] idx = 0 - df_prof = [ - pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\') - for i in range(jobs) - ] + df_prof = [pd.read_csv(f"results_{i}.csv") for i in range(jobs)] for config in configs: file_idx = idx % jobs - tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))] + tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))] idx += 1 thread_pool.close() thread_pool.join() @@ -366,11 +358,12 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k stride_bias = bias.stride(0) if use_bias else 0 EVEN_K = K % block_k == 0 + num_xcds = 1 if split_k > 1 else 8 matmul_kernel[grid](a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K) + kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds) return c @@ -441,7 +434,7 @@ def parse_args(): parser.add_argument("--num_threads", type=int, default=32, help="number of threads to use for kernel compilation and post processing") parser.add_argument("--jobs", type=int, default=1, help="number of tasks during the profiling process") - parser.add_argument("--iters", type=int, default=1000, help="number of iterations used in --benchmark mode") + parser.add_argument("--iters", type=int, default=200, help="number of iterations used in --benchmark mode") parser.add_argument("--init_type", type=str, default='randn', choices=['randn', 'hpl', 'trig_float', 'zeros'], help="Input tensor initialization (default normal distribution)") parser.add_argument( diff --git a/python/perf-kernels/tune_gemm/utils/file_generator.py b/python/perf-kernels/tune_gemm/utils/file_generator.py index 1011bc9df805..d92079dab9a0 100644 --- a/python/perf-kernels/tune_gemm/utils/file_generator.py +++ b/python/perf-kernels/tune_gemm/utils/file_generator.py @@ -76,6 +76,10 @@ def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype use_bias = bias_size > 0 + ## Let's enable xcd-based pid remapping only when split-K is NOT used + ## Also #xcd is fixed to 8. If we are tuning for MI308, please change it to 4 + num_xcds = 1 if split_k > 1 else 8 + if warmup: torch_dtype_a = 'fp16' torch_dtype_b = 'fp16' @@ -89,6 +93,7 @@ def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype matmul_def_str = f""" def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): + grid_mn = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n}) matmul_kernel_{configStr}.warmup( {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c}, M, N, K, @@ -103,8 +108,10 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn): waves_per_eu = {waves_per_eu}, matrix_instr_nonkdim = {mfmaInstrSize}, kpack = {kpack}, - BIAS={use_bias}, - EVEN_K={EVEN_K}, + BIAS = {use_bias}, + EVEN_K = {EVEN_K}, + GRID_MN = grid_mn, + NUM_XCDS = {num_xcds}, grid=(1,), ) return None @@ -136,7 +143,9 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn): matrix_instr_nonkdim = {mfmaInstrSize}, kpack = {kpack}, BIAS = {use_bias}, - EVEN_K = {EVEN_K} + EVEN_K = {EVEN_K}, + GRID_MN = grid[0], + NUM_XCDS = {num_xcds} ) return c """ @@ -310,7 +319,7 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, ini # call all matmul_xxx functions idx = 0 - runs = iters if run_bench else 200 + runs = iters if run_bench else 120 for config in configs: configStr = gen_configStr(config) matmul_call_str = f""" From 177d0bd5ceef8bb9cfcbc73d4479fc9f8ff26969 Mon Sep 17 00:00:00 2001 From: xiaohuguo2023 <149615094+xiaohuguo2023@users.noreply.github.com> Date: Mon, 19 Aug 2024 23:20:04 +0100 Subject: [PATCH 13/20] add barrier to fix racing for spinning locks (#632) --- python/perf-kernels/streamk/streamk_kernel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/perf-kernels/streamk/streamk_kernel.py b/python/perf-kernels/streamk/streamk_kernel.py index 138e6540e203..42b861950a9b 100644 --- a/python/perf-kernels/streamk/streamk_kernel.py +++ b/python/perf-kernels/streamk/streamk_kernel.py @@ -201,6 +201,7 @@ def streamk_gemm( rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N) P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :] tl.store(P_, acc) + tl.debug_barrier() tl.atomic_xchg(locks + pid, 1) start_iter = end_iter From e42690dba7168be101645052af752dec9d47991f Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Thu, 8 Aug 2024 09:47:29 +0000 Subject: [PATCH 14/20] Softmax kernel --- .../amd_perf_kernel_Integration_tests.yml | 2 + python/perf-kernels/README.md | 4 + python/perf-kernels/softmax.py | 219 ++++++++++++++++++ 3 files changed, 225 insertions(+) create mode 100644 python/perf-kernels/softmax.py diff --git a/.github/workflows/amd_perf_kernel_Integration_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml index 956ff8903115..266018a2cf0c 100644 --- a/.github/workflows/amd_perf_kernel_Integration_tests.yml +++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml @@ -126,6 +126,8 @@ jobs: - name: Run Perf Kernels Unit Tests run: | pytest -vvv ./python/perf-kernels/flash-attention.py + pytest -vvvv ./python/perf-kernels/softmax.py - name: Run Perf Kernels Benchmark run: | python ./python/perf-kernels/flash-attention.py + python ./python/perf-kernels/softmax.py diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md index b8f930ef94ea..663f5333cc13 100644 --- a/python/perf-kernels/README.md +++ b/python/perf-kernels/README.md @@ -69,3 +69,7 @@ small block sizes aren't natively supported by `tl.dot` operator. Despite being numerically correct, this kernel performed worse than a corresponding GEMM kernel that used `tl.dot` with minimum block size equal to $16$. + +## `softmax.py` + +Kernel that implements Softmax over a row of tensor. diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py new file mode 100644 index 000000000000..bd00f24c42fc --- /dev/null +++ b/python/perf-kernels/softmax.py @@ -0,0 +1,219 @@ +import argparse +import torch +import sys +import pytest + +import triton +import triton.language as tl +from triton.runtime import driver + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', + 'gfx90a', 'gfx908') + + +def get_cuda_autotune_config(): + return [ + triton.Config({}, num_warps=4, num_stages=1), + triton.Config({}, num_warps=8, num_stages=1), + triton.Config({}, num_warps=16, num_stages=1), + ] + + +def get_hip_autotune_config(): + return [ + triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + + +@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, + BLOCK_SIZE: tl.constexpr): + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + for row_idx in tl.range(row_start, n_rows, row_step): + row_start_ptr = input_ptr + row_idx * input_row_stride + input_ptrs = row_start_ptr + col_offsets + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + row = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + output_ptrs = tl.multiple_of(output_ptrs, (16, )) + tl.store(output_ptrs, softmax_output, mask=mask) + + +device = torch.cuda.current_device() +properties = driver.active.utils.get_device_properties(device) +NUM_SM = properties["multiprocessor_count"] + + +def softmax(x): + n_rows, n_cols = x.shape + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + y = torch.empty_like(x) + + #Persistent kernel. Simply, set num of programs equal to number of streaming multi-processors + num_programs = min(NUM_SM, n_rows) + + grid = lambda meta: (num_programs, ) + softmax_kernel[grid]( + y, + x, + x.stride(0), + y.stride(0), + n_rows, + n_cols, + BLOCK_SIZE, + ) + + return y + + +def run_softmax(M, N): + print(f"Running Softmax on shape ({M},{N})") + torch.manual_seed(0) + x = torch.randn(M, N, device='cuda') + y_triton = softmax(x) + + return y_triton + + +#pytest +@pytest.mark.parametrize('M, N', [ + (1823, 781), + (1, 1), + (128, 1), + (1, 128), + (8192, 8192), + (4096, 8192), + (359, 1), + (1, 359), + (1, 131072), +]) +def test_softmax(M, N): + torch.manual_seed(0) + x = torch.randn(M, N, device='cuda') + y_triton = softmax(x) + y_torch = torch.softmax(x, axis=1) + assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) + + +#Benchmark +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def run_benchmark(args): + config = [] + if (args.M_benchmark): + val = args.M_start + x_vals_list = [] + while val <= args.M_end: + x_vals_list.append(val) + val *= args.M_step + mn_args = {'N': args.N_start} + plot_name = str("softmax-performance_" + args.dtype + "_N" + str(args.N_start) + "_M" + str(args.M_start) + + "-" + str(args.M_end) + "-" + str(args.M_step)) + x_names = ['M'] + else: + x_vals_list = [i for i in range(args.N_start, args.N_end, args.N_step)] + mn_args = {'M': args.M_start} + plot_name = str("softmax-performance_" + args.dtype + "_M" + str(args.M_start) + "_N" + str(args.N_start) + + "-" + str(args.N_end) + "-" + str(args.N_step)) + x_names = ['N'] + dtype = arg_to_torch_dtype[args.dtype] + + print(plot_name) + config.append( + triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg='provider', + line_vals=['triton', 'torch'], + line_names=[ + "Triton", + "Torch", + ], + styles=[('blue', '-'), ('green', '-')], + ylabel="GB/s", + plot_name=plot_name, + args=mn_args, + )) + + @triton.testing.perf_report(config) + def benchmark(M, N, provider): + x = torch.randn(M, N, device='cuda', dtype=dtype) + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + if provider == 'torch': + ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) + if provider == 'triton': + ms = triton.testing.do_bench(lambda: softmax(x)) + gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms) + + benchmark.run(save_path=".", show_plots=True, print_data=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark Softmax", + allow_abbrev=False, + ) + + parser.add_argument('-M', "--M_start", default="1", type=int) + parser.add_argument('-Ms', "--M_step", default="2", type=int) + parser.add_argument('-Me', "--M_end", default="512", type=int) + parser.add_argument('-Mb', "--M_benchmark", default=False, type=bool) + + parser.add_argument('-N', "--N_start", default="1024", type=int) + parser.add_argument('-Ns', "--N_step", default="2048", type=int) + parser.add_argument('-Ne', "--N_end", default="65536", type=int) + + parser.add_argument('-d', "--dtype", default="fp16") + parser.add_argument('-nb', "--no_benchmark", default=False, type=bool) + + return parser.parse_args() + + +def main(): + args = parse_args() + if args.no_benchmark: + run_softmax(args.M_start, args.N_start) + else: + run_benchmark(args) + + +if __name__ == "__main__": + sys.exit(main()) From 3704738d49e26ccc4ab8c2e48ccfdcb1c43bcf59 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Fri, 6 Sep 2024 13:41:52 -0500 Subject: [PATCH 15/20] Move utility tools from triton-mlir to main_perf branch (#635) * Move utility tools from triton-mlir to main_perf branch - Plot layout script - occ.sh - amdgcn-cfg * yapf format * More formats * remove executablility of plot_layout.py * Address ruff complains * Move tune_gemm to tools --- .../perf-kernels/tools/amdgcn-cfg/README.md | 14 + .../tools/amdgcn-cfg/amdgcn-cfg.py | 222 +++++ python/perf-kernels/tools/occ.sh | 71 ++ .../perf-kernels/tools/plot-layout/README.md | 117 +++ .../tools/plot-layout/plot_layout.py | 290 ++++++ .../tools/plot-layout/tikzplot.tex | 880 ++++++++++++++++++ .../{ => tools}/tune_gemm/README.md | 0 .../{ => tools}/tune_gemm/icache_flush.py | 0 .../{ => tools}/tune_gemm/matmul_kernel.py | 0 .../{ => tools}/tune_gemm/one_config.py | 0 .../{ => tools}/tune_gemm/tune_gemm.py | 0 .../tune_gemm/utils/file_generator.py | 0 .../{ => tools}/tune_gemm/utils/utils.py | 0 13 files changed, 1594 insertions(+) create mode 100644 python/perf-kernels/tools/amdgcn-cfg/README.md create mode 100644 python/perf-kernels/tools/amdgcn-cfg/amdgcn-cfg.py create mode 100755 python/perf-kernels/tools/occ.sh create mode 100644 python/perf-kernels/tools/plot-layout/README.md create mode 100644 python/perf-kernels/tools/plot-layout/plot_layout.py create mode 100644 python/perf-kernels/tools/plot-layout/tikzplot.tex rename python/perf-kernels/{ => tools}/tune_gemm/README.md (100%) rename python/perf-kernels/{ => tools}/tune_gemm/icache_flush.py (100%) rename python/perf-kernels/{ => tools}/tune_gemm/matmul_kernel.py (100%) rename python/perf-kernels/{ => tools}/tune_gemm/one_config.py (100%) rename python/perf-kernels/{ => tools}/tune_gemm/tune_gemm.py (100%) rename python/perf-kernels/{ => tools}/tune_gemm/utils/file_generator.py (100%) rename python/perf-kernels/{ => tools}/tune_gemm/utils/utils.py (100%) diff --git a/python/perf-kernels/tools/amdgcn-cfg/README.md b/python/perf-kernels/tools/amdgcn-cfg/README.md new file mode 100644 index 000000000000..bea420ea530c --- /dev/null +++ b/python/perf-kernels/tools/amdgcn-cfg/README.md @@ -0,0 +1,14 @@ +# Control Flow Graph Generator from AMDGCN assembly + +The script reads an assembly file and generates a Control Flow Graph (CFG) for each function in the file. The graph can be saved in `dot`, `svg` and `pdf` formats. The nodes of a graph can be represented with 1) just labels or 2) the corresponding assembly code. The edges of a graph can help to identify cycles and, thus, to provide a better navigation through the code. + + +### Basic usage + +``` +python ./amdgcn-cfg.py -i -o / -f [dot|svg|pdf] +``` + +`dot`-files can be visualize with [this](https://dreampuf.github.io/GraphvizOnline) online tool. You just need to copy and paste the content of a generated `dot`-file. + +By default, the nodes are named with basic block labels. Use `-v` or `--verbose` option to add assembly source code to corresponding nodes. diff --git a/python/perf-kernels/tools/amdgcn-cfg/amdgcn-cfg.py b/python/perf-kernels/tools/amdgcn-cfg/amdgcn-cfg.py new file mode 100644 index 000000000000..ae2f65830766 --- /dev/null +++ b/python/perf-kernels/tools/amdgcn-cfg/amdgcn-cfg.py @@ -0,0 +1,222 @@ +import os +import argparse +import re +from collections import OrderedDict +import graphviz + + +class Options: + + def __init__(self, input_file, output_file, verbose, format): + if not os.path.exists(input_file): + raise RuntimeError('input file is not provided') + + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + raise RuntimeError('output directory does not exist') + + self.input_file = input_file + self.output_file = output_file + self.verbose = verbose + self.format = format + self.output_dir = output_dir + + +class Block: + + def __init__(self, label, code): + self.label = label + self.code = code + self.edges = [] + + +class Kernel: + + def __init__(self, kernel_name, blocks): + self.name = kernel_name + self.blocks = blocks + self.cfg = None + + +begin_label = 'Begin' +end_label = 'End' + + +def find_kernel(text): + func_name_expr = r'^([^\s^\.]\w.+):' + func_name = None + start = None + for index, line in enumerate(text): + match = re.search(func_name_expr, line) + if match is not None: + func_name = match[1] + start = index + break + if start is None: + return None, None, None + + end = None + for index, line in enumerate(text): + if re.search(r's_endpgm', line) is not None: + end = index + break + + if end is None: + return None, None, None + + return func_name, text[start:end + 1], end + + +def find_label(kernel): + label = None + index = None + for index, line in enumerate(kernel): + match = re.search(r'^\.(\w+):', line) + if match is not None: + label = match[1] + break + return label, index + + +def get_block_list(kernel): + label, index = find_label(kernel) + + blocks = OrderedDict() + if (index > 1): + blocks[begin_label] = Block(begin_label, kernel[:index - 1]) + + while label is not None: + kernel = kernel[index + 1:] + next_label, next_index = find_label(kernel) + if next_label is None: + code = kernel[index:] + else: + code = kernel[:next_index] + blocks[label] = Block(label, code) + + label = next_label + index = next_index + + blocks[end_label] = Block(end_label, []) + + return blocks + + +def find_terminators(code): + terminator_labels = [] + for line in code: + branch = re.search(r'(c)?branch.*\s+\.?(.*)', line) + if branch is not None: + is_condional = True if len(branch.groups()) == 2 else False + label_idx = 2 if is_condional else 1 + terminator_labels.append(branch[label_idx]) + if not is_condional: + return terminator_labels, True + end = re.search(r's_endpgm', line) + if end is not None: + terminator_labels.append(end_label) + return terminator_labels, True + + return terminator_labels, False + + +def add_edges(kernel): + keys = list(kernel.blocks.keys()) + for index, curr_label in enumerate(keys): + if curr_label == end_label: + continue + + code = kernel.blocks[curr_label].code + terminators, is_last_unconditional = find_terminators(code[:-1]) + + if is_last_unconditional: + # unconditional jump in the middle of the block + break + + # handle the last terminator in the current BB + last_terminator, is_unconditional = find_terminators([code[-1]]) + + is_conditional = not is_unconditional + next_block_label = keys[index + 1] + is_next_covered = next_block_label in terminators + + if last_terminator: + terminators.extend(last_terminator) + if is_conditional and not is_next_covered: + next_block_label = keys[index + 1] + terminators.append(next_block_label) + else: + if not is_next_covered: + next_block_label = keys[index + 1] + terminators.append(next_block_label) + + assert (len(terminators)) + kernel.blocks[curr_label].edges = terminators + + +def generate_cfg(kernel, options): + graph = graphviz.Digraph(f'{kernel.name}') + for curr_label in kernel.blocks: + block = kernel.blocks[curr_label] + asm = [line.strip() for line in block.code] + if options.verbose: + label_text = repr('\n'.join([f'{curr_label}', *asm])) + else: + label_text = curr_label + graph.node(curr_label, shape='rect', labeljust='l', margin='0.01', label=label_text) + + for curr_label in kernel.blocks: + block = kernel.blocks[curr_label] + for edge in block.edges: + graph.edge(curr_label, edge) + + return graph + + +def main(options): + asm = [] + with open(options.input_file, 'r') as file: + context = file.readlines() + for line in context: + asm.append(line[:-1]) + + kernels = [] + last_end_index = 0 + while last_end_index is not None: + func_name, kernel_asm, last_end_index = find_kernel(asm) + if kernel_asm is None: + break + + blocks = get_block_list(kernel_asm) + kernel = Kernel(func_name, blocks) + add_edges(kernel) + + cfg = generate_cfg(kernel, options) + kernel.cfg = cfg + kernels.append(kernel) + asm = asm[last_end_index + 1:] + + for index, kernel in enumerate(kernels): + output_file_name = f'{options.output_file}.kernel-{index}' + if options.format == 'dot': + with open(f'{output_file_name}.dot', 'w') as file: + file.write(str(kernel.cfg)) + file.write('\n') + else: + kernel.cfg.render( + filename=f'{output_file_name}', + format=options.format, + ).replace('\\', '/') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(prog="Generates Control Flow Graph (CFG) from amdgcn assembly file", ) + parser.add_argument("-i", "--input", type=str, default=None, help="input file") + parser.add_argument("-o", "--output", type=str, default=None, help="output file prefix") + parser.add_argument("-v", "--verbose", action='store_true', help='verbose output') + parser.add_argument("-f", "--format", choices=['dot', 'svg', 'pdf'], default="dot", help="output format type") + args = parser.parse_args() + + options = Options(args.input, args.output, args.verbose, args.format) + + main(options) diff --git a/python/perf-kernels/tools/occ.sh b/python/perf-kernels/tools/occ.sh new file mode 100755 index 000000000000..51c8f9095907 --- /dev/null +++ b/python/perf-kernels/tools/occ.sh @@ -0,0 +1,71 @@ +#! /bin/bash + +## $1: input script that contains one kernel + +rm -rf ~/.triton/cache/ + +export MLIR_ENABLE_DUMP=1 +export AMDGCN_ENABLE_DUMP=1 +## Assume CDNA arch +SIMD=4 +LDS_SIZE=65536 +TOTAL_VGPR=512 + +get_occ_per_CU() { + ## $1: vgpr count + vgpr=$1 + occPerEU=$((TOTAL_VGPR/vgpr)) + if [[ $vgpr -gt 256 ]]; then + occPerEU=1 + elif [[ $vgpr -gt 168 ]]; then + occPerEU=2 + elif [[ $vgpr -gt 128 ]]; then + occPerEU=3 + elif [[ $vgpr -gt 96 ]]; then + occPerEU=4 + elif [[ $vgpr -gt 80 ]]; then + occPerEU=5 + elif [[ $vgpr -gt 72 ]]; then + occPerEU=6 + elif [[ $vgpr -gt 64 ]]; then + occPerEU=7 + else + occPerEU=8 + fi + + occPerCU=$((occPerEU*SIMD/num_warps)) + echo $occPerCU +} + +$1 > output.mlir 2>&1 + +LDS_line=$(sed -n '/triton_gpu\.shared\ /p' output.mlir | tail -n 1 | grep -o 'triton_gpu.shared = [0-9]*') +numWarps_line=$(sed -n '/triton_gpu\.num-warps/p' output.mlir | tail -n 1 | grep -o 'triton_gpu.num-warps. = [0-9]*') + +LDS=${LDS_line##*=} +num_warps=${numWarps_line##*=} +echo "LDS: $LDS, num_warps: $num_warps" + +VGPRs=$(sed -n '/vgpr_count/p' output.mlir | tail -n 1 | awk '{print $2}') +SPILLs=$(sed -n '/vgpr_spill/p' output.mlir | tail -n 1 | awk '{print $2}') + +echo "VGPRS: $VGPRs (spill: $SPILLs)" + +occLDSPerCU=$((LDS_SIZE/LDS)) +occVgprPerCU=$(get_occ_per_CU $VGPRs) +occPerCU=$occVgprPerCU +if [ $occLDSPerCU -lt $occVgprPerCU ];then + occPerCU=$occLDSPerCU +fi +occPerEU=$((occPerCU*num_warps/SIMD)) +echo "occupancy: $occPerEU waves/SIMD or $occPerCU workgroups/CU (occLDSPerCU: $occLDSPerCU, occVgprPerCU: $occVgprPerCU)" + +perf=$(tail -n 2 output.mlir) +echo "$perf" + +## remove distracting info from the assembly +sed -i '/local_/! {/\.loc/d}' output.mlir +sed -i '/\.Ltmp.*:/d' output.mlir +sed -i '/AMD clang version/d' output.mlir + +sed -n '/AMDGCN/, $p' output.mlir > output.amdgcn diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md new file mode 100644 index 000000000000..40de35bdb3aa --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -0,0 +1,117 @@ +# Plot script for triton layouts + +This script is used to draw triton layouts in the context of matmul. +Here is the help info from the script. + +```bash +>$ python3 plot_layout.py -h +usage: Draw triton layouts [-h] [-shape SHAPE SHAPE SHAPE] [-plot {blocked,dot,wmma,lds}] [-nonKDim {16,32}] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] + [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-kWidth {4,8,16}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] + +options: + -h, --help show this help message and exit + -shape SHAPE SHAPE SHAPE + Tensor shape in the form of M,N,K + -plot {blocked,dot,wmma,lds} + choose plot mode + -nonKDim {16,32} mfma instruction dim + -sizePerThread SIZEPERTHREAD SIZEPERTHREAD + -threadsPerWarp THREADSPERWARP THREADSPERWARP + -warpsPerCTA WARPSPERCTA WARPSPERCTA + -order ORDER ORDER + -kWidth {4,8,16} number of elements per thread + -lds_layout {swizzle,padding,none} + choose the LDS data layout + -lds_access {read,write,none} + choose LDS access mode + -wave_size {32,64} choose the wmma instruction mode + -o O output pdf file name (without surfix) + -mfmaTrans If set, then use mfma.trans layout + -keep If set, keep the generated .tex file +``` + +## Installation +This script does not require torch or triton to be installed. The only package +it depends on is latex. On Ubuntu, do +```bash +sudo apt install texlive-full +``` + +## Draw blocked layout (`-plot blocked`) + +Examples: +```bash +python3 plot_layout.py -plot blocked -shape 128 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1 +python3 plot_layout.py -plot blocked -shape 16 128 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2 +python3 plot_layout.py -plot blocked -shape 32 128 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1 +``` + +Blocked layouts are used during global load. It is used to describe the layout of the tensor +for pointers and results. +We can provide tensor shape (`-shape M N K`) and blocked layout parameters ( +`-sizePerThread x y`, `-threadsPerWarp x y`, and `-warpsPerCTA x y`). +We can also provide the order of the tensor as `-order x y` to control which dim +is the fastest changing dimension. + +Notes +- All of the gemm dims (M, N, and K) are needed when providing the shape. But only + M and K will be used to plot the layout of the tensor. +- The script does not support the case when threads are loading elements that are + out of the boundary of the tensor dimensions. This means + - For M: sizePerThread[0] * threadsPerWarps[0] * warpsPerCTA[0] <= M + - For K: sizePerThread[1] * threadsPerWarps[1] * warpsPerCTA[1] <= K + + +## Draw mfma operand and result layouts (`-plot dot`) + +Examples: +```bash +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4 +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8 +python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 +``` + +This mode draws two graphs: +1. The layout of the whole tile for tile A, B, and C +2. The layout of a single mfma block, operands and results of one or more mfma + instructions that share the same accumulating VGPRs. + This view has thread distributions among tensor elements. + +Knobs +- `-kWidth`: the number of elements that will be loaded into one thread at once +- `-nonKDim`: 16 ot 32, which is used to control the mfma instruction size +- `-mfmaTrans`: if set, the transposed mfma layout will be plotted. + +Notes +- The layout shows the mapping from the threads/wave to the elements in the + original tensor. It does not care if the elements are arranged in LDS, like + swizzling to avoid bank conflicts. +- The script does not allow settings for data type or k dim of the mfma instruction. + This can be controled by the `-kWidth` flag. + - For example, if we want `mfma_32x32x8xf16`, we can set `-nonKDim 32` and `-kWidth 4`. + - If we want `mfma_32x32x16xf8`, we can set `-nonKDim 32` and `-kWidth 8`. + + +## Draw LDS access (`-plot lds`) + +Examples: +```bash +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -shape 128 128 64 -kWidth 8 +``` + +Knobs +- `kWidth` here means the vector size when accessing LDS +- Three options for `-lds_layout`: + - `none`: no swizzling, no padding + - `padding`: padding at every 128B + - `swizzling`: apply the swizzling pattern, which is derived from tensor shape and kWidth. +- Three options for `-lds_access`: + - `none`: do not plot access pattern + - `read`: plot accessed elements during ds_read + - `write`: plot accessed elements during ds_write. Note that this needs some infomation from + global load. Therefore, we need to provide `-sizePerThread` and `-threadsPerWarp`. + +Notes +- This mode is rarely used. If you have any questions, please contact Lixun Zhang directly. diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py new file mode 100644 index 000000000000..599f92c790e4 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -0,0 +1,290 @@ +import argparse +import sys +import os +import subprocess + + +def draw_preamble_cmd(): + return '''\\documentclass[tikz, border=1mm, dvipsnames]{standalone} +\\usepackage{ifthen} +\\usepackage{tikz} +\\usetikzlibrary{arrows.meta,arrows} +\\usetikzlibrary{intersections} +\\usetikzlibrary{calc, quotes} +\\usetikzlibrary{patterns} +\\usepackage{xparse} + +\\ExplSyntaxOn +\\NewExpandableDocumentCommand{\\bitwiseXor}{mm} + { + \\recuenco_bitwise_xor:nn { #1 } { #2 } + } + +\\cs_new:Nn \\recuenco_bitwise_xor:nn + { + \\int_from_bin:e + { + \\__recuenco_bitwise_xor:ee { \\int_to_bin:n { #1 } } { \\int_to_bin:n { #2 } } + } + } +\\cs_generate_variant:Nn \\int_from_bin:n { e } + +\\cs_new:Nn \\__recuenco_bitwise_xor:nn + { + \\__recuenco_bitwise_xor_binary:ee + { + \\prg_replicate:nn + { + \\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #1 } + } + { 0 } + #1 + } + { + \\prg_replicate:nn + { + \\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #2 } + } + { 0 } + #2 + } + } +\\cs_generate_variant:Nn \\__recuenco_bitwise_xor:nn { ee } + +\\cs_new:Nn \\__recuenco_bitwise_xor_binary:nn + { + \\__recuenco_bitwise_xor_binary:w #1;#2; + } +\\cs_generate_variant:Nn \\__recuenco_bitwise_xor_binary:nn { ee } + +\\cs_new:Npn \\__recuenco_bitwise_xor_binary:w #1#2;#3#4; + { + \\int_abs:n { #1-#3 } + \\tl_if_empty:nF { #2 } { \\__recuenco_bitwise_xor_binary:w #2;#4; } + } + +\\ExplSyntaxOff''' + + +def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack): + return f'''\\begin{{document}} + \\begin{{tikzpicture}} + \\def\\scale{{1}} + \\def\\elem{{0.04}} + \\coordinate (C TL) at (0,0); + \\def\\opColorAL{{magenta}} + \\def\\opColorAR{{cyan}} + \\def\\opColorBL{{Maroon}} + \\def\\opColorBR{{BlueGreen}} + \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kpack}}} + + \\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$); + \\def\\mfmaTrans{{{trans}}} + + %% Draw zoomed in view of mfma + \\def\\elem{{.16}} + \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} + \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} + \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} + \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kpack}*\\elem, 0)$); + \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}} + + \\end{{tikzpicture}} +\\end{{document}}''' + + +def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order): + return f'''\\begin{{document}} + \\begin{{tikzpicture}} + \\def\\scale{{1}} + \\def\\elem{{0.06}} + \\coordinate (TL) at (0,0); + \\drawBlockedTensor{{{M}}}{{{K}}}{{{sizePerThread[0]}}}{{{sizePerThread[1]}}}{{{threadsPerWarp[0]}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{order[0]}}} + \\end{{tikzpicture}} +\\end{{document}}''' + + +def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp): + if ldsLayout == 'swizzle': + hasSwizzle = 1 + elif ldsLayout == 'padding': + hasSwizzle = 2 + else: + hasSwizzle = 0 + + if ldsAccess == 'read': + accessMode = 1 + elif ldsAccess == 'write': + accessMode = 2 + else: + accessMode = 0 + + return f'''\\begin{{document}} + \\begin{{tikzpicture}} + \\def\\scale{{1}} + \\def\\M{{{M}}} + \\def\\K{{{K}}} + \\def\\vec{{{kpack}}} + \\def\\hasSwizzle{{{hasSwizzle}}} + \\def\\accessMode{{{accessMode}}} + + \\def\\sizePerThreadK{{{sizePerThread[1]}}} + \\def\\sizePerThreadM{{{sizePerThread[0]}}} + \\def\\threadsPerWarpK{{{threadsPerWarp[1]}}} + + \\def\\elem{{0.18}} + \\coordinate (TL) at (0,0); + \\drawTensorLayoutGlobalMem + \\coordinate (TL) at ($(TL)+(0, -24*\\elem-10*\\elem)$); + \\drawLDSLayoutTritonSwizzling{{\\hasSwizzle}}{{\\accessMode}} + \\end{{tikzpicture}} +\\end{{document}}''' + + +def draw_wmma_instr_cmd(waveSize): + wmma_mode = 0 if waveSize == 32 else 1 + return f'''\\begin{{document}} + \\begin{{tikzpicture}} + \\def\\scale{{1}} + \\coordinate (C TL) at (0,0); + \\def\\elem{{0.25}} + \\drawWMMAInstr{{{wmma_mode}}}{{1}} + \\end{{tikzpicture}} +\\end{{document}}''' + + +def run_bash_command(commandstring): + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) + return proc.stdout.splitlines() + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Draw triton layouts", + allow_abbrev=False, + ) + ## tensor shapes + parser.add_argument("-shape", type=int, nargs=3, default=(32, 128, 64), help='Tensor shape in the form of M,N,K') + parser.add_argument("-plot", type=str, default="blocked", choices=['blocked', 'dot', 'wmma', 'lds'], + help='choose plot mode') + parser.add_argument("-nonKDim", type=int, default=32, choices=[16, 32], help='mfma instruction dim') + ## blocked layout parameters + parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4)) + parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4)) + parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4)) + parser.add_argument("-order", type=int, nargs=2, default=(1, 0)) + ## LDS access parameters + parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16], help='number of elements per thread') + parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], + help='choose the LDS data layout') + parser.add_argument("-lds_access", type=str, default="none", choices=['read', 'write', 'none'], + help='choose LDS access mode') + ## wmma instruction layout parameter + parser.add_argument("-wave_size", type=int, default=32, choices=[32, 64], help='choose the wmma instruction mode') + + parser.add_argument("-o", type=str, default="myplot", help='output pdf file name (without surfix)') + parser.add_argument("-mfmaTrans", action='store_true', default=False, help='If set, then use mfma.trans layout') + parser.add_argument("-keep", action='store_true', default=False, help='If set, keep the generated .tex file') + + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + + shape = args.shape + M = shape[0] + N = shape[1] + K = shape[2] + plot_mode = args.plot + mfmaNonKDim = args.nonKDim + kpack = args.kWidth + trans = 1 if args.mfmaTrans else 0 + ofilename = args.o + keepSrc = args.keep + + ldsLayout = args.lds_layout + ldsAccess = args.lds_access + + waveSize = args.wave_size + + sizePerThread = args.sizePerThread + threadsPerWarp = args.threadsPerWarp + warpsPerCTA = args.warpsPerCTA + order = args.order + + CTAShape = [] + if plot_mode == 'blocked': + print(f"Plotting tensor M={M},K={K} with blocked layout:") + print(f"sizePerThread={sizePerThread}", end=" ") + print(f"threadsPerWarp={threadsPerWarp}", end=" ") + print(f"warpsPerCTA={warpsPerCTA}", end=" ") + print(f"order={order}", end=" ") + CTAShape.append(sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0]) + CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1]) + + if plot_mode == 'dot': + mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16" + mfma_trans_str = ".trans" if trans else "" + print(f"Plotting dot operation with shapes M={M},N={N},K={K}") + print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kpack}", end=" ") + print(f"warpsPerCTA={warpsPerCTA}", end=" ") + CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) + CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) + + if plot_mode == 'blocked' or plot_mode == 'dot': + print(f"CTAShape={CTAShape}") + assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M" + + if plot_mode == 'blocked': + assert K != 0 and CTAShape[1] <= K and K % CTAShape[1] == 0, "bad tensor dimension K" + + if plot_mode == 'dot': + assert N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N" + assert K != 0 and K % (2 * kpack) == 0, "bad tensor dimension K" + + if plot_mode == 'lds': + print(f"Plotting LDS access for tensor M={M},K={K} with vec={kpack}") + if ldsAccess == 'write': + print(f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}") + + with open("myplot.tex", 'w') as f_plot: + with open("tikzplot.tex") as file: + tikz_code = file.read() + + preamble_str = draw_preamble_cmd() + + draw_blockedLayout_str = draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order) + + draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack) + + draw_lds_str = draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) + + draw_wmma_str = draw_wmma_instr_cmd(waveSize) + + f_plot.write(preamble_str + "\n") + f_plot.write(tikz_code) + if plot_mode == 'blocked': + f_plot.write(draw_blockedLayout_str) + elif plot_mode == 'dot': + f_plot.write(draw_dotLayout_str) + elif plot_mode == 'lds': + f_plot.write(draw_lds_str) + elif plot_mode == 'wmma': + f_plot.write(draw_wmma_str) + + run_bash_command(f"pdflatex -jobname {ofilename} myplot.tex") + print(f"plot saved in {ofilename}.pdf") + + ## Remove au files + os.remove(f"{ofilename}.aux") + os.remove(f"{ofilename}.log") + if not keepSrc: + os.remove("myplot.tex") + run_bash_command("rm -rf ./auto") + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/python/perf-kernels/tools/plot-layout/tikzplot.tex b/python/perf-kernels/tools/plot-layout/tikzplot.tex new file mode 100644 index 000000000000..d8441b042f02 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/tikzplot.tex @@ -0,0 +1,880 @@ +\newcommand{\drawBlockedWave}[5]{ + %% + %% Draw a wave coverage with blocked layout + %% + %% Wave TL: pre defined top-left coordinate of the wave + %% \elem: pre defined variable + %% + %% #1: sizePerThread[0] --> sizePerThreadM + %% #2: sizePerThread[1] --> sizePerThreadN + %% #3: threadsPerWarp[0] --> threadsPerWarpM + %% #4: threadsPerWarp[1] --> threadsPerWarpN + %% #5: fastest changing dim --> order + + \pgfmathsetmacro{\sizePerThreadM}{#1} + \pgfmathsetmacro{\sizePerThreadN}{#2} + \pgfmathsetmacro{\threadsPerWarpM}{#3} + \pgfmathsetmacro{\threadsPerWarpN}{#4} + \pgfmathsetmacro{\order}{#5} + + \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} + \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} + + \foreach \tid in {0,...,63}{ + \pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)} + \pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)} + \coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$); + \pgfmathsetmacro{\ratio}{\tidM*10} + + \ifthenelse{\tid = 0}{ + \draw [line width = 0.01mm, fill=red] (Thread TL) + rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); + }{ + \draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL) + rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); + } + } + \draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem); +} + +\newcommand{\drawBlockedCTA}[7]{ + %% + %% Draw a CTA coverage with blocked layout + %% + %% CTA TL: pre defined top-left coordinate of the CTA + %% \elem: pre defined variable + %% + %% #1: sizePerThread[0] --> sizePerThreadM + %% #2: sizePerThread[1] --> sizePerThreadN + %% #3: threadsPerWarp[0] --> threadsPerWarpM + %% #4: threadsPerWarp[1] --> threadsPerWarpN + %% #5: warpsPerCTA[0] --> warpsPerCTAM + %% #6: warpsPerCTA[1] --> warpsPerCTAN + %% #7: fastest changing dim --> order + + \pgfmathsetmacro{\sizePerThreadM}{#1} + \pgfmathsetmacro{\sizePerThreadN}{#2} + \pgfmathsetmacro{\threadsPerWarpM}{#3} + \pgfmathsetmacro{\threadsPerWarpN}{#4} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\order}{#7} + + \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} + \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} + \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} + \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} + + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1} + + \coordinate (Wave TL) at (CTA TL); + \drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order} + \foreach \waveId in {0,...,\maxWaveId}{ + \ifthenelse{\order=1} + { + \pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)} + \pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)} + \pgfmathsetmacro{\rot}{0} + }{ + \pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)} + \pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)} + \pgfmathsetmacro{\rot}{90} + } + + \coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$); + \draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem) + node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId}; + } + + \draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem); +} + +\newcommand{\drawBlockedTensor}[8]{ + %% + %% Draw a tensor with blocked layout of the following parameters + %% sizePerThread[2] + %% threadsPerWarp[2] + %% warpsPerCTA[2] + %% order[2] + %% + %% TL: pre defined top-left coordinate of the tensor + %% \elem: pre defined variable + %% + %% #1: tensorShape[0] --> M + %% #2: tensorShape[1] --> N + %% #3: sizePerThread[0] --> sizePerThreadM + %% #4: sizePerThread[1] --> sizePerThreadN + %% #5: threadsPerWarp[0] --> threadsPerWarpM + %% Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0] + %% #6: warpsPerCTA[0] --> warpsPerCTAM + %% #7: warpsPerCTA[1] --> warpsPerCTAN + %% #8: fastest changing dim --> order + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\sizePerThreadM}{#3} + \pgfmathsetmacro{\sizePerThreadN}{#4} + \pgfmathsetmacro{\threadsPerWarpM}{#5} + \pgfmathsetmacro{\warpsPerCTAM}{#6} + \pgfmathsetmacro{\warpsPerCTAN}{#7} + \pgfmathsetmacro{\order}{#8} + + \pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM} + \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} + \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} + \pgfmathsetmacro{\CTARepM}{\M/\CTASizeM} + \pgfmathsetmacro{\CTARepN}{\N/\CTASizeN} + \pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1} + + \foreach \ctaId in {0,...,\maxCTAId}{ + \pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)} + \pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)} + \coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$); + \drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order} + } + + \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {M=\M}; + \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {K=\N}; + + \def\zoomR{1.5} + \coordinate (zoomin BL) at ($(TL)+(0, .3)$); + + \foreach \hl in {0,...,\sizePerThreadM}{ + \draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0); + } + \foreach \vl in {0,...,\sizePerThreadN}{ + \draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR); + } + + \node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$}; + \node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN}; + + \draw [densely dotted] (TL) -- (zoomin BL); + \draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$); + \draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); +} + +\newcommand{\drawBlockMFMALayoutLarge}[3]{ + %% + %% Draw a single block of MFMA_32x32x8xf16 or MFMA_16x16x16xf16 + %% + %% block TL: pre-defined top-left coordinate of the block + %% \elem: pre defined variable + %% + %% #1: 1 for mfma.trans, 0 for normal mfma + %% #2: mfmaNonKDim + %% #3: verbose. 1 means draw tid in each vec; 0 means draw nothing + + \pgfmathsetmacro{\trans}{#1} + \pgfmathsetmacro{\nonTrans}{1-#1} + \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} + \pgfmathsetmacro{\groups}{64/\nonKDim} + \pgfmathsetmacro{\maxGID}{\groups-1} + \pgfmathsetmacro{\maxIVec}{\nonKDim*\nonKDim/256-1} + \pgfmathsetmacro{\verbose}{#3} + \foreach \iVec in {0,...,\maxIVec} { + \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*\groups*4*\elem, -\nonTrans*\iVec*\groups*4*\elem)$); + \foreach \tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\colID}{\tg+4} + \pgfmathsetmacro{\col}{\Colors[\colID]} + \foreach \tid in {0,...,\maxTID} { + \pgfmathsetmacro{\ratio}{\tid*2.5*\groups+15} + \ifthenelse{\verbose=0}{ + \draw [line width=0.005mm, fill=\col!\ratio!white] + ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) + rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem); + }{ + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} + \draw [line width=0.005mm, fill=\col!\ratio!white] + ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) + rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem) + node [pos=.5, scale=.35*\scale, rotate=90*\nonTrans] {t\drawTid}; + } + } + } + } + \draw [thick] (block TL) rectangle ++(\nonKDim*\elem, -\nonKDim*\elem); +} + + +\newcommand{\drawTensorMFMALayout}[6]{ + %% + %% Draw a tensor with mfma layout. + %% + %% C TL: pre defined top-left coordinates of the tensor + %% + %% #1: M + %% #2: N + %% #3: MFMA nonKDim + %% #4: warpsPerCTA[0] + %% #5: warpsPerCTA[1] + %% #6: 1 for mfma.trans, 0 for normal mfma + + \pgfmathsetmacro{\tensorShapeH}{#1} + \pgfmathsetmacro{\tensorShapeW}{#2} + \pgfmathsetmacro{\mfmaNonKDim}{#3} + \pgfmathsetmacro{\warpsPerCTAH}{#4} + \pgfmathsetmacro{\warpsPerCTAW}{#5} + \pgfmathsetmacro{\mfmaTrans}{#6} + + \coordinate (old TL) at (TL); + \coordinate (TL) at (C TL); + + + \pgfmathsetmacro{\CTARepH}{\tensorShapeH/\mfmaNonKDim/\warpsPerCTAH} + \pgfmathsetmacro{\CTARepW}{\tensorShapeW/\mfmaNonKDim/\warpsPerCTAW} + \pgfmathsetmacro{\maxCTAId}{\CTARepH*\CTARepW-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAH*\warpsPerCTAW-1} + \pgfmathsetmacro{\CTASizeH}{\warpsPerCTAH*\mfmaNonKDim} + \pgfmathsetmacro{\CTASizeW}{\warpsPerCTAW*\mfmaNonKDim} + + + \foreach \ctaId in {0,...,\maxCTAId}{ + \pgfmathsetmacro{\ctaCoordH}{int(\ctaId/\CTARepW)} + \pgfmathsetmacro{\ctaCoordW}{mod(\ctaId,\CTARepW)} + \coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$); + %% Draw a detailed view of wave0 in each CTA + \coordinate (block TL) at (CTA TL); + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{0} + + \foreach \waveId in {0,...,\maxWaveId}{ + \pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)} + \pgfmathsetmacro{\waveCoordW}{mod(\waveId,\warpsPerCTAW)} + \coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$); + %% Inside the loop, only draw a rectangle + \draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem) + node [scale=.7*\mfmaNonKDim/32*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId}; + } + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\CTASizeW*\elem, -\CTASizeH*\elem); + } + + \coordinate (TL) at (old TL); +} + +\newcommand{\drawMFMAOperand}[4]{ + %% + %% Draw one mfma operand + %% + %% mfma op TL: pre defined coordinates of the top-left + %% \elem: pre defined variable + %% + %% #1: mfmNonKDim + %% #2: kpack + %% #3: 0 for opA and 1 for opB + %% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing + + \pgfmathsetmacro{\nonKDim}{#1} + \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} + \pgfmathsetmacro{\kpack}{#2} + \pgfmathsetmacro{\opIdxA}{#3} + \pgfmathsetmacro{\opIdxB}{1-\opIdxA} + \pgfmathsetmacro{\verbose}{#4} + + \foreach \col/\tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\col}{\Colors[\tg]} + \foreach \tid in {0,...,\maxTID} { + % \pgfmathsetmacro{\ratio}{\tid*2.5+15} + \ifthenelse{\verbose=0}{ + \draw [line width=0.005mm, fill=\col] + ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) + rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); + }{ + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} + \draw [line width=0.005mm, fill=\col] + ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) + rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA) + node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; + } + } + } +} + +\newcommand{\drawWaveOperand}[4]{ + %% + %% Draw the part of the tensor that is one operand of the wave + %% + %% Op TL: pre defined coordinates of the top-left of the operand + %% \elem: pre defined variable + %% + %% #1: K + %% #2: mfmNonKDim + %% #3: kpack + %% #4: 0 for opA and 1 for opB + + \pgfmathsetmacro{\K}{#1} + \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\groups}{64/\nonKDim} + \pgfmathsetmacro{\kpack}{#3} + \pgfmathsetmacro{\opIdx}{#4} + \pgfmathsetmacro{\opIdxOther}{1-\opIdx} + + \coordinate (TL) at (Op TL); + + \pgfmathsetmacro{\numKRep}{\K/\kpack/\groups} + \pgfmathsetmacro{\maxKRepId}{\numKRep-1} + + \foreach \repId in {0,...,\maxKRepId}{ + \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\opIdxOther, -\repId*\groups*\kpack*\elem*\opIdx)$); + \drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0} + \draw [thick] (mfma op TL) rectangle + ++(\groups*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\elem*\opIdx); + } +} + +\newcommand{\drawDotOperands}[7]{ + %% + %% Draw operand tensors of dot + %% + %% A TL and B TL: pre defined top-left coordinates of A and B tensor + %% \elem: pre defined variable + %% + %% #1: M + %% #2: N + %% #3: K + %% #4: MFMA nonKDim + %% #5: warpsPerCTA[0] + %% #6: warpsPerCTA[1] + %% #7: kpack + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\K}{#3} + \pgfmathsetmacro{\mfmaNonKDim}{#4} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\kpack}{#7} + + %% operand A + \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} + \pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1} + \foreach \ctaId in {0,...,\maxCTAIdM}{ + \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*\mfmaNonKDim*\elem)$); + \foreach \waveId in {0,...,\maxWaveId}{ + \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*\mfmaNonKDim*\elem)$); + \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -\mfmaNonKDim*\elem); + } + %% Only draw the detailed view of the first wave in CTA + \coordinate (Op TL) at (CTA TL); + \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0} + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); + } + \draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem); + + + %% operand B + \pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/\mfmaNonKDim} + \pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1} + \foreach \ctaId in {0,...,\maxCTAIdN}{ + \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*\mfmaNonKDim*\elem, 0)$); + \foreach \waveId in {0,...,\maxWaveId}{ + \coordinate (wave TL) at ($(CTA TL)+(\waveId*\mfmaNonKDim*\elem ,0)$); + \draw [ultra thin] (wave TL) rectangle ++(\mfmaNonKDim*\elem, -\K*\elem); + } + %% Only draw the detailed view of the first wave in CTA + \coordinate (Op TL) at (CTA TL); + \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1} + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); + } + \draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem); +} + + +\newcommand{\drawDot}[8]{ + %% + %% Draw C = dot A, B + %% + %% C TL: pre defined top-left coordinates of the result tensor + %% \elem: pre defined variable + %% + %% #1: M + %% #2: N + %% #3: K + %% #4: MFMA nonKDim + %% #5: warpsPerCTA[0] + %% #6: warpsPerCTA[1] + %% #7: 1 for mfma.trans, 0 for normal mfma + %% #8: kpack + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\K}{#3} + \pgfmathsetmacro{\mfmaNonKDim}{#4} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\mfmaTrans}{#7} + \pgfmathsetmacro{\kpack}{#8} + \pgfmathsetmacro{\kdim}{int(\groups*\kpack)} + + \pgfmathsetmacro{\gap}{\elem*20} + \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); + \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); + + \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack} + + \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} + + %% Draw labels + \node [scale=\scale, above] at ($(A TL)+(.5*\K*\elem, 0)$) {K=\K}; + \node [scale=\scale, above, rotate=90] at ($(A TL)+(0, -.5*\M*\elem)$) {M=\M}; + + \node [scale=\scale, above, rotate=90] at ($(B TL)+(0, -.5*\K*\elem)$) {K=\K}; + \node [scale=\scale, above] at ($(B TL)+(.5*\N*\elem, 0)$) {N=\N}; + + \node [scale=\scale, above left] at (A TL) {A}; + \node [scale=\scale, above left] at (B TL) {B}; + \node [scale=\scale, above left] at (C TL) {C}; + + %% label nonKDim + \node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; + \node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; + %% label kpack + \node [scale=.8*\scale, above] at ($(A TL)+(0.5*\groups*\kpack*\elem, 0)$) {\kdim}; + \node [scale=.8*\scale, left] at ($(B TL)+(0, -0.5*\groups\kpack*\elem)$) {\kdim}; +} + +\newcommand{\Colors}{{ + "red", + "YellowGreen", + "blue", + "Maroon", + "orange", + "cyan", + "magenta", + "brown", + "teal", + "purple", + "gray", + "Green", + "BlueGreen", + "violet", + "olive", + "darkgray", + }} + +\newcommand{\drawTensorLayoutGlobalMem}{ + %% + %% Draw tensor layout in global memory without any swizzling + %% + %% TL: pre defined top-left coordinates of the tensor in global memory + %% \elem: per defined variable + %% \Colors: a pre defined array of 16 colors + %% + %% The following arguments are also expected to be pre defined + %% #1: M + %% #2: K + %% #3: vec: number of elements in a group + + \pgfmathsetmacro{\numVecK}{\K/\vec} + \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} + \pgfmathsetmacro{\drawM}{20} + + %% Draw the tensor, but only draw 32 rows + \draw (TL) rectangle ++(\K*\elem, -\drawM*\elem); + %% Draw detailed vec view of the tensor + \foreach \vecId in {0,...,\maxVecId}{ + + \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} + \pgfmathsetmacro{\vecCoordK}{mod(\vecId,\numVecK)} + \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); + + \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} + \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} + \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} + \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} + + \draw [ultra thin, fill=\vecColor!\ratio!white] (vec TL) rectangle ++(\vec*\elem, -\elem) + node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + + } + %% M and K dim + \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem-8*\elem)$) {M=\M}; + \node [scale=.8*\scale, left] at ($(TL)+(0, -.5*16*\elem)$) {16}; + \node [scale=\scale, above] at ($(TL)+(.5*\K*\elem, 0)$) {K=\K}; + %% label for vecSize + \def\vecR{1.5} + \coordinate (vec TL) at ($(TL)+(-.25*\vec*\elem, 3*\elem*\vecR)$); + \pgfmathsetmacro{\maxVec}{\vec-1} + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (TL) -- ($(vec TL)+(0, -\elem*\vecR)$); + \draw [densely dotted] ($(TL)+(\vec*\elem, 0)$) -- ($(vec TL)+(\vec*\elem*\vecR, -\elem*\vecR)$); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\vec*\elem*\vecR, 0)$) {vec=\vec}; +} + + + +\newcommand{\drawLDSLayoutTritonSwizzling}[2]{ + %% + %% Draw tensor layout in LDS with swizzling + %% + %% TL: pre defined top-left coordinates of the tensor in global memory + %% \elem: per defined variable + %% \Colors: a pre defined array of 16 colors + %% + %% The following three arguments are expected to be pre defined + %% #1: M + %% #2: K + %% #3: vec: number of elements in a group + %% + %% #1: hasSwizzle, 0 means no swizzling and no padding, + %% 1 means optimal swizzling + %% 2 means padding + %% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write + %% For ds_write access, the following variables are assumed to be pre defined + %% \sizePerThreadK + %% \sizePerThreadM + %% \threadsPerWarpK + + \pgfmathsetmacro{\hasSwizzle}{#1} + \pgfmathsetmacro{\accessMode}{#2} + \pgfmathsetmacro{\numVecK}{\K/\vec} + + %% Assuming fp16 data type + \pgfmathsetmacro{\LDSK}{64} + \pgfmathsetmacro{\numLDSVec}{\LDSK/\vec} + \pgfmathsetmacro{\swizzleK}{max(\LDSK, \K)} + \pgfmathsetmacro{\LDSM}{int(\M/\LDSK*\K)} + + \ifthenelse{\accessMode = 2}{ + %% \accessMode == 2, draw 8 rows + \pgfmathsetmacro{\maxVecId}{8*\numVecK-1} + \pgfmathsetmacro{\drawM}{8*\K/\LDSK+4} + }{ + %% \accessMode == 0 or 1, draw 16 rows + \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} + \pgfmathsetmacro{\drawM}{16*\K/\LDSK+4} + } + + %% Parameters used for swizzling + \pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec} + %% perPhase = ceil(LDSK / K) + %% The number of the rows of the tensor that can share the same swizzling pattern + \pgfmathsetmacro{\perPhase}{ceil(\LDSK/\K)} + %% maxPhase: the total number of different swizzling patterns + \ifthenelse{\hasSwizzle=0}{ + %% When swizzling is disabled + \pgfmathsetmacro{\maxPhase}{1} + }{ + %% When vec is small enough, we want 16/perPhase different swizzling patterns + %% When vec is large, we can only have 64 / \vec different swizzling pattern at most + \pgfmathsetmacro{\maxPhase}{min(16/\perPhase,64/\vec)} + } + + %% Draw the LDS + \draw (TL) rectangle ++(\LDSK*\elem, -\drawM*\elem); + + %% Draw detailed vec view of LDS + \foreach \vecId in {0,...,\maxVecId}{ + \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} + \pgfmathsetmacro{\vecCoordK}{int(mod(\vecId,\numVecK))} + \pgfmathsetmacro{\rawPhase}{floor(\vecId/\numVecSwizzleK)} + %% vec color + \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} + \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} + \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} + \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} + + %% old vec coordinates + \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); + + %% new vec coordinates in LDS by swizzling + %% The following two conditions correspond to the relation between \LDSK and \K + \ifthenelse{\LDSK < \K}{ + \pgfmathsetmacro{\vecLDSM}{\vecCoordM*\K/\LDSK+floor(\vecCoordK*\vec/\LDSK)} + \pgfmathsetmacro{\vecLDSK}{int(mod(\vecCoordK, \LDSK/\vec))} + }{ + \pgfmathsetmacro{\vecLDSM}{floor(\vecCoordM/\perPhase)} + \pgfmathsetmacro{\vecLDSK}{int(\vecCoordK+mod(\vecCoordM,\perPhase)*\numVecK)} + } + %% + \pgfmathsetmacro{\phase}{int(mod(\rawPhase, \maxPhase))} + %% Compute the swizzled col id + \pgfmathsetmacro{\vecLDSKSwizzled}{\bitwiseXor{\vecLDSK}{\phase}} + + %% new vec coordinates in LDS by padding + \pgfmathsetmacro{\numPads}{floor(\vecId/\numLDSVec)} + \pgfmathsetmacro{\bankId}{\vec/2*\vecId+\numPads} + \pgfmathsetmacro{\vecPadM}{int(\bankId/32)} + \pgfmathsetmacro{\vecPadK}{int(mod(\bankId,32))} + + \ifthenelse{\hasSwizzle = 2}{ + %% vec coordinates by padding + \coordinate (new vec TL) at ($(TL)+(\vecPadK*2*\elem, -\vecPadM*\elem)$); + \pgfmathsetmacro{\tailBankId}{int(\vecPadK+\vec/2-1)} + }{ + %% vec coordinates by swizzling + \coordinate (new vec TL) at ($(TL)+(\vecLDSKSwizzled*\vec*\elem, -\vecLDSM*\elem)$); + \pgfmathsetmacro{\tailBankId}{0} + } + + \ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{ + \pgfmathsetmacro{\nextBanks}{\tailBankId-31} + \pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks} + \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem) + node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + \draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$) + rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + }{ + \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem) + node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; + } + + %% ds_read + %% Highlight the elements the first 16 threads access in the first cycle + %% This is used to visualize bank conflicts + \ifthenelse{\accessMode = 1}{ + \ifthenelse{\vecCoordK = 0}{ + \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); + \draw (new vec TL) -- ++(\elem, -\elem); + \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); + }{} + }{} + + %% Draw ds_write pattern + \ifthenelse{\accessMode = 2}{ + %% First compute the coverage of the first 16 threads + \pgfmathsetmacro{\covK}{min(16, \threadsPerWarpK)*\sizePerThreadK/\vec} + \pgfmathsetmacro{\covM}{ceil(16/\threadsPerWarpK)*\sizePerThreadM} + %% Check conditions for the first 16 threads + \pgfmathsetmacro{\vecInThread}{int(mod(\vecCoordK, \sizePerThreadK/\vec))} + \ifthenelse{\vecInThread=0}{ + \ifthenelse{\vecCoordK<\covK \AND \vecCoordM<\covM}{ + \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); + \draw (new vec TL) -- ++(\elem, -\elem); + \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); + }{} + }{} + }{} + + %% Label the phase of each line if swizzling is used + \ifthenelse{\hasSwizzle = 2}{}{ + \pgfmathsetmacro{\lastVecId}{int(64/\vec)-1} + \ifthenelse{\vecLDSKSwizzled = \lastVecId}{ + \draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0) + node [scale=.6*\scale, right] {\phase}; + }{} + } + } + + %% Draw boundary of 32 banks + %% Assume fp16 data type + \foreach \bank in {0,...,31}{ + \draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem) + node [scale=.6*\scale, right, black] {\bank}; + } + \draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem); + \node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id}; + + \node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS 32 banks}; + \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem)$) {LDSM=\LDSM}; + + %% label phase if swizzling is used + \ifthenelse{\hasSwizzle = 2}{}{ + \node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase}; + } +} + +\newcommand{\drawMFMAInstr}[3]{ + %% + %% Draw layout of mfma instructions with tid labeled + %% + %% C TL: pre defined top-left coordinates of the output matrix + %% \elem: pre defined variable + %% + %% #1: mfmaNonKDim + %% #2: kpack + %% #3: mfmaTrans + \pgfmathsetmacro{\mfmaNonKDim}{#1} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} + \pgfmathsetmacro{\kpack}{#2} + \pgfmathsetmacro{\mfmaTrans}{#3} + \pgfmathsetmacro{\nonTrans}{1-#3} + + \pgfmathsetmacro{\gap}{\elem*5} + \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elem, 0)$); + \coordinate (mfma op TL) at (mfma opA TL); + \drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1} + \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elem)$); + \drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1} + + \coordinate (block TL) at (C TL); + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} + + %% Draw labels + \def\vecR{1.5} + \coordinate (vec TL) at ($(mfma opA TL)+(-.25*\kpack*\elem, 3*\elem*\vecR)$); + \pgfmathsetmacro{\maxVec}{\kpack-1} + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem*\vecR)$); + \draw [densely dotted] ($(mfma opA TL)+(\kpack*\elem, 0)$) -- ($(vec TL)+(\kpack*\elem*\vecR, -\elem*\vecR)$); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\kpack*\elem*\vecR, 0)$) {vec=\kpack}; + + \coordinate (vec TL) at ($(mfma op TL)+(-3*\elem*\vecR, .25*\kpack*\elem)$); + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem*\vecR,0)$); + \draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elem)$) -- ($(vec TL)+(\elem*\vecR, -\kpack*\elem*\vecR)$); + \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem*\vecR)$) {vec=\kpack}; + + \node [scale=\scale, below] at ($(block TL)+(.5*\mfmaNonKDim*\elem,-\mfmaNonKDim*\elem)$) {outC}; + \ifthenelse{\mfmaTrans=0}{ + \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opA}; + \node [scale=\scale, above] at (mfma op TL) {opB}; + \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem*\vecR, .25*4*\elem)$); + \foreach \vecId in {0,1,2,3}{ + \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem); + \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem); + \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4}; + \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; + }{ + \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB}; + \node [scale=\scale, above] at (mfma op TL) {opA}; + \coordinate (vec TL) at ($(block TL)+(-.25*4*\elem, 3*\elem+\elem*\vecR)$); + \foreach \vecId in {0,1,2,3}{ + \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); + } + \draw [densely dotted] (block TL) -- ++(-.25*4*\elem, 3*\elem); + \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(.25*4*\elem, 3*\elem); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem*\vecR, 0)$) {vec=4}; + \node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=True}; + } +} + +\newcommand{\drawWMMAOperand}[3]{ + %% + %% Draw the layout of one operand of WMMA instruction + %% + %% #1: opIdx. 0 for opA, 1 for opB + %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% #3: mode. 0 for w32, 1 for w64 + %% + %% wmma op TL: pre defined top-left coordinates of the operand matrix + + \pgfmathsetmacro{\isOpB}{#1} + \pgfmathsetmacro{\isOpA}{1-\isOpB} + \pgfmathsetmacro{\verbose}{#2} + \pgfmathsetmacro{\isWLarge}{#3} + + \foreach \row in {0,...,15}{ + \pgfmathsetmacro{\ratio}{\row*5+15} + \coordinate (vec TL) at ($(wmma op TL)+(\row*\isOpB*\elem, -\row*\elem*\isOpA)$); + \ifthenelse{\isWLarge=1}{ + \pgfmathsetmacro{\tidone}{int(\row+16)} + \pgfmathsetmacro{\tidtwo}{int(\row+32)} + \pgfmathsetmacro{\tidthree}{int(\row+48)} + \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) + rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) + node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone, t\tidtwo, t\tidthree}; + }{ + \pgfmathsetmacro{\tidone}{int(\row+16)} + \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) + rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) + node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone}; + } + } +} + +\newcommand{\drawWMMAResult}[2]{ + %% + %% Draw layout of WMMA result tensor + %% + %% #1: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% #2: mode. 0 for w32, 1 for w64 + + \pgfmathsetmacro{\verbose}{#1} + \pgfmathsetmacro{\isWLarge}{#2} + + \pgfmathsetmacro{\numElem}{256} + \pgfmathsetmacro{\maxElemId}{\numElem-1} + + \foreach \elemId in {0,...,\maxElemId}{ + %% figure out the rowID + \pgfmathsetmacro{\rowId}{floor(\elemId/16)} + %% figure out the colID + \pgfmathsetmacro{\colId}{mod(\elemId,16)} + %% figure out the tid and color + \ifthenelse{\isWLarge=1}{ + \pgfmathsetmacro{\tid}{int(mod(\elemId,64))} + \pgfmathsetmacro{\laneId}{mod(\elemId,64)} + }{ + \pgfmathsetmacro{\tid}{int(mod(\elemId,32))} + \pgfmathsetmacro{\laneId}{mod(\elemId,32)} + } + %% figure out the color + \pgfmathsetmacro{\colorId}{floor(\laneId/16)} + \pgfmathsetmacro{\vecColor}{\Colors[\colorId]} + %% Coordinate + \coordinate (vec TL) at ($(C TL)+(\colId*\elem, -\rowId*\elem)$); + \draw [line width=0.005mm, fill=\vecColor!60!white] (vec TL) rectangle ++(\elem, -\elem) + node [scale=.4*\scale, pos=.5] {t\tid}; + } + + +} + +\newcommand{\drawWMMAInstr}[2]{ + %% + %% Draw wmma instruction layouts 16x16x16 + %% + %% #1: mode. 0 for w32, 1 for w64 + %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% + %% C TL: pre defined top-left coordinates of output matrix + %% \elem: pre defined element size + + + \pgfmathsetmacro{\isWLarge}{#1} + \pgfmathsetmacro{\verbose}{#2} + + \pgfmathsetmacro{\gap}{\elem*2} + \coordinate (wmma op TL) at ($(C TL)+(-\gap-16*\elem, 0)$); + \coordinate (wmma opA TL) at (wmma op TL); + \drawWMMAOperand{0}{\verbose}{\isWLarge} + \coordinate (wmma op TL) at ($(C TL)+(0, \gap+16*\elem)$); + \drawWMMAOperand{1}{\verbose}{\isWLarge} + + \drawWMMAResult{1}{\isWLarge} + + %% labels + \pgfmathsetmacro{\gap}{\elem} + \node [above left, scale=\scale] at (wmma opA TL) {A}; + \node [above left, scale=\scale] at (wmma op TL) {B}; + \node [above right, scale=\scale] at ($(C TL)+(16*\elem, 0)$) {C}; + + %% A k dim + \node [scale=.8*\scale] (k dim A) at ($(wmma opA TL)+(8*\elem,\gap)$) {16}; + \draw [->, >=stealth] (k dim A.west) -- ($(wmma opA TL)+(0, \gap)$); + \draw [->, >=stealth] (k dim A.east) -- ($(wmma opA TL)+(16*\elem, \gap)$); + + %% B K dim + \node [scale=.8*\scale, rotate=90] (k dim B) at ($(wmma op TL)+(-\gap, -8*\elem)$) {16}; + \draw [->, >=stealth] (k dim B.east) -- ($(wmma op TL)+(-\gap, 0)$); + \draw [->, >=stealth] (k dim B.west) -- ($(wmma op TL)+(-\gap, -16*\elem)$); + + %% C M dim + \node [scale=.8*\scale] (m dim) at ($(C TL)+(8*\elem,-16*\elem-\gap)$) {16}; + \draw [->, >=stealth] (m dim.west) -- ($(C TL)+(0, -16*\elem-\gap)$); + \draw [->, >=stealth] (m dim.east) -- ($(C TL)+(16*\elem, -16*\elem-\gap)$); + + %% C N dim + \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16}; + \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$); + \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$); +} diff --git a/python/perf-kernels/tune_gemm/README.md b/python/perf-kernels/tools/tune_gemm/README.md similarity index 100% rename from python/perf-kernels/tune_gemm/README.md rename to python/perf-kernels/tools/tune_gemm/README.md diff --git a/python/perf-kernels/tune_gemm/icache_flush.py b/python/perf-kernels/tools/tune_gemm/icache_flush.py similarity index 100% rename from python/perf-kernels/tune_gemm/icache_flush.py rename to python/perf-kernels/tools/tune_gemm/icache_flush.py diff --git a/python/perf-kernels/tune_gemm/matmul_kernel.py b/python/perf-kernels/tools/tune_gemm/matmul_kernel.py similarity index 100% rename from python/perf-kernels/tune_gemm/matmul_kernel.py rename to python/perf-kernels/tools/tune_gemm/matmul_kernel.py diff --git a/python/perf-kernels/tune_gemm/one_config.py b/python/perf-kernels/tools/tune_gemm/one_config.py similarity index 100% rename from python/perf-kernels/tune_gemm/one_config.py rename to python/perf-kernels/tools/tune_gemm/one_config.py diff --git a/python/perf-kernels/tune_gemm/tune_gemm.py b/python/perf-kernels/tools/tune_gemm/tune_gemm.py similarity index 100% rename from python/perf-kernels/tune_gemm/tune_gemm.py rename to python/perf-kernels/tools/tune_gemm/tune_gemm.py diff --git a/python/perf-kernels/tune_gemm/utils/file_generator.py b/python/perf-kernels/tools/tune_gemm/utils/file_generator.py similarity index 100% rename from python/perf-kernels/tune_gemm/utils/file_generator.py rename to python/perf-kernels/tools/tune_gemm/utils/file_generator.py diff --git a/python/perf-kernels/tune_gemm/utils/utils.py b/python/perf-kernels/tools/tune_gemm/utils/utils.py similarity index 100% rename from python/perf-kernels/tune_gemm/utils/utils.py rename to python/perf-kernels/tools/tune_gemm/utils/utils.py From f80aed70afcebb3dc6a1101c503393a7a8d2957a Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Tue, 27 Aug 2024 20:26:39 +0000 Subject: [PATCH 16/20] Add rmsnorm kernel --- .../amd_perf_kernel_Integration_tests.yml | 4 +- python/perf-kernels/README.md | 4 + python/perf-kernels/rmsnorm.py | 208 ++++++++++++++++++ 3 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 python/perf-kernels/rmsnorm.py diff --git a/.github/workflows/amd_perf_kernel_Integration_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml index 266018a2cf0c..61e44c4859d0 100644 --- a/.github/workflows/amd_perf_kernel_Integration_tests.yml +++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml @@ -100,7 +100,7 @@ jobs: matrix: runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} container: - image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 + image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4 options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root steps: - name: Checkout @@ -127,7 +127,9 @@ jobs: run: | pytest -vvv ./python/perf-kernels/flash-attention.py pytest -vvvv ./python/perf-kernels/softmax.py + pytest -vvv ./python/perf-kernels/rmsnorm.py - name: Run Perf Kernels Benchmark run: | python ./python/perf-kernels/flash-attention.py python ./python/perf-kernels/softmax.py + python ./python/perf-kernels/rmsnorm.py diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md index 663f5333cc13..e81d5f4b8e56 100644 --- a/python/perf-kernels/README.md +++ b/python/perf-kernels/README.md @@ -73,3 +73,7 @@ used `tl.dot` with minimum block size equal to $16$. ## `softmax.py` Kernel that implements Softmax over a row of tensor. + +## `rmsnorm.py` + +Kernel that implements RMS Norm over a row of tensor. diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py new file mode 100644 index 000000000000..7aa247761cbd --- /dev/null +++ b/python/perf-kernels/rmsnorm.py @@ -0,0 +1,208 @@ +import argparse +import torch +import sys +import pytest + +import triton +import triton.language as tl + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def get_cuda_autotune_config(): + return [ + triton.Config({}, num_warps=4, num_stages=1), + triton.Config({}, num_warps=8, num_stages=1), + triton.Config({}, num_warps=16, num_stages=1), + ] + + +def get_hip_autotune_config(): + return [ + triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), + triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + + +@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) +@triton.jit +def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, + BLOCK_SIZE: tl.constexpr): + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + for row_idx in tl.range(row_start, n_rows, row_step): + row_start_ptr = input_ptr + row_idx * input_row_stride + input_ptrs = row_start_ptr + col_offsets + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg") + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg") + row_norm = row * row #square each value + row_norm = tl.sum(row_norm, axis=-1) #sum across columns(axis=-1) + row_norm = row_norm / n_cols #divide by n_cols + row_norm = row_norm + epsilon #add epsilon + row_norm = tl.rsqrt(row_norm) #take rsqrt, this is normalization value + rms_norm = row * row_norm #multiply each x by normalization value + rms_norm = rms_norm * g #element wise multiplication with g + + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + output_ptrs = tl.multiple_of(output_ptrs, (16, )) + tl.store(output_ptrs, rms_norm, mask=mask) + + +def rmsnorm(x, epsilon=1e-6): + n_rows, n_cols = x.shape + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + y = torch.empty_like(x, device='cuda') + g = torch.ones((1, n_cols), device='cuda') + + num_programs = n_rows + grid = lambda meta: (num_programs, ) + rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE) + + return y + + +def run_rmsnorm(M, N): + torch.manual_seed(0) + x = torch.randn(M, N, device='cuda') + y_triton = rmsnorm(x) + + return y_triton + + +@pytest.mark.parametrize('M, N', [ + (1, 4), + (2, 10), + (8192, 4096), + (4096, 8192), + (1, 8192), + (873, 1245), +]) +def test_rmsnorm(M, N): + torch.manual_seed(0) + x = torch.randn(M, N, device='cuda') + y_triton = rmsnorm(x) + + rms_norm = torch.nn.RMSNorm(N, device='cuda') + y_torch = rms_norm(x) + + assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) + + +#Benchmark +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def torch_rmsnorm(x): + M, N = x.shape + rms_norm = torch.nn.RMSNorm(N, device='cuda') + y_torch = rms_norm(x) + + return y_torch + + +def run_benchmark(args): + config = [] + if (args.M_benchmark): + val = args.M_start + x_vals_list = [] + while val <= args.M_end: + x_vals_list.append(val) + val *= args.M_step + mn_args = {'N': args.N_start} + plot_name = str("rmsnorm-performance_" + args.dtype + "_N" + str(args.N_start) + "_M" + str(args.M_start) + + "-" + str(args.M_end) + "-" + str(args.M_step)) + x_names = ['M'] + else: + x_vals_list = [i for i in range(args.N_start, args.N_end, args.N_step)] + mn_args = {'M': args.M_start} + x_names = ['N'] + plot_name = str("rmsnorm-performance_" + args.dtype + "_M" + str(args.M_start) + "_N" + str(args.N_start) + + "-" + str(args.N_end) + "-" + str(args.N_step)) + + dtype = arg_to_torch_dtype[args.dtype] + + print(plot_name) + config.append( + triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg='provider', + line_vals=['triton', 'torch'], + line_names=["Triton", "Torch"], + styles=[('blue', '-'), ('green', '-')], + ylabel="GB/s", + plot_name=plot_name, + args=mn_args, + )) + + @triton.testing.perf_report(config) + def benchmark(M, N, provider): + x = torch.randn(M, N, device='cuda', dtype=dtype) + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + if provider == 'torch': + ms = triton.testing.do_bench(lambda: torch_rmsnorm(x)) + if provider == 'triton': + ms = triton.testing.do_bench(lambda: rmsnorm(x)) + gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms) + + benchmark.run(save_path=".", show_plots=True, print_data=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark RMSNorm", + allow_abbrev=False, + ) + + parser.add_argument('-M', "--M_start", default="1", type=int) + parser.add_argument('-Ms', "--M_step", default="2", type=int) #This is multiplicative step + parser.add_argument('-Me', "--M_end", default="512", type=int) + parser.add_argument('-Mb', "--M_benchmark", default=False, type=bool) + + parser.add_argument('-N', "--N_start", default="8192", type=int) + parser.add_argument('-Ns', "--N_step", default="1024", type=int) + parser.add_argument('-Ne', "--N_end", default="32768", type=int) + + parser.add_argument('-d', "--dtype", default="fp16") + parser.add_argument('-nb', "--no_benchmark", default=False, type=bool) + + return parser.parse_args() + + +def main(): + args = parse_args() + if args.no_benchmark: + run_rmsnorm(args.M_start, args.N_start) + else: + run_benchmark(args) + + +if __name__ == "__main__": + sys.exit(main()) From a782caf0152558d5aac5dc81392277a178007c98 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Fri, 13 Sep 2024 15:28:14 +0000 Subject: [PATCH 17/20] Online softmax implementation --- python/perf-kernels/softmax.py | 71 +++++++++++++++++----------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py index bd00f24c42fc..60eefb91986f 100644 --- a/python/perf-kernels/softmax.py +++ b/python/perf-kernels/softmax.py @@ -5,7 +5,6 @@ import triton import triton.language as tl -from triton.runtime import driver def is_cuda(): @@ -52,43 +51,52 @@ def get_autotune_config(): @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) @triton.jit -def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, - BLOCK_SIZE: tl.constexpr): +def softmax_kernel_online(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, + BLOCK_SIZE: tl.constexpr): + row_start = tl.program_id(0) - row_step = tl.num_programs(0) - col_offsets = tl.arange(0, BLOCK_SIZE) - mask = col_offsets < n_cols - for row_idx in tl.range(row_start, n_rows, row_step): - row_start_ptr = input_ptr + row_idx * input_row_stride + row_idx = row_start + + #loop 1, find max and sum + m = -float('inf') #Initial value of max + row_sum = 0.0 + row_start_ptr = input_ptr + row_idx * input_row_stride + for b in tl.range(0, n_cols, BLOCK_SIZE): + col_offsets = b + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < n_cols + row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block + m_p = tl.max(row_block, axis=0) #find block max + m_p = tl.maximum(m, m_p) #Find new max across all blocks so far + row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum + row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block + m = m_p #save max + + output_row_start_ptr = output_ptr + row_idx * output_row_stride + #Loop 2 + for b in tl.range(0, n_cols, BLOCK_SIZE): + col_offsets = b + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - row = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - output_row_start_ptr = output_ptr + row_idx * output_row_stride + mask = col_offsets < n_cols + row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block + #subtract, exponentiate and divide by sum + softmax_output = tl.exp(row_block - m) / row_sum + #store output_ptrs = output_row_start_ptr + col_offsets - output_ptrs = tl.multiple_of(output_ptrs, (16, )) tl.store(output_ptrs, softmax_output, mask=mask) -device = torch.cuda.current_device() -properties = driver.active.utils.get_device_properties(device) -NUM_SM = properties["multiprocessor_count"] - - def softmax(x): n_rows, n_cols = x.shape - BLOCK_SIZE = triton.next_power_of_2(n_cols) + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols)) y = torch.empty_like(x) - #Persistent kernel. Simply, set num of programs equal to number of streaming multi-processors - num_programs = min(NUM_SM, n_rows) + num_programs = n_rows grid = lambda meta: (num_programs, ) - softmax_kernel[grid]( + softmax_kernel_online[grid]( y, x, x.stride(0), @@ -111,17 +119,8 @@ def run_softmax(M, N): #pytest -@pytest.mark.parametrize('M, N', [ - (1823, 781), - (1, 1), - (128, 1), - (1, 128), - (8192, 8192), - (4096, 8192), - (359, 1), - (1, 359), - (1, 131072), -]) +@pytest.mark.parametrize('M, N', [(1823, 781), (1, 1), (128, 1), (1, 128), (8192, 8192), (4096, 8192), (359, 1), + (1, 359), (1, 131072), (1, 89999)]) def test_softmax(M, N): torch.manual_seed(0) x = torch.randn(M, N, device='cuda') From 893932515f202d2981f634e9c1b94b3f87609d50 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Mon, 23 Sep 2024 15:57:54 +0000 Subject: [PATCH 18/20] Use mask during load for Softmax --- python/perf-kernels/softmax.py | 51 ++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py index 60eefb91986f..19764b2138e3 100644 --- a/python/perf-kernels/softmax.py +++ b/python/perf-kernels/softmax.py @@ -51,39 +51,66 @@ def get_autotune_config(): @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) @triton.jit -def softmax_kernel_online(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, - BLOCK_SIZE: tl.constexpr): +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, +): row_start = tl.program_id(0) row_idx = row_start #loop 1, find max and sum + loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1 m = -float('inf') #Initial value of max row_sum = 0.0 row_start_ptr = input_ptr + row_idx * input_row_stride - for b in tl.range(0, n_cols, BLOCK_SIZE): - col_offsets = b + tl.arange(0, BLOCK_SIZE) + for b in tl.range(0, loop_num): + col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets - mask = col_offsets < n_cols - row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block + row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block m_p = tl.max(row_block, axis=0) #find block max m_p = tl.maximum(m, m_p) #Find new max across all blocks so far row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block m = m_p #save max + col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < n_cols + row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block + m_p = tl.max(row_block, axis=0) #find block max + m_p = tl.maximum(m, m_p) #Find new max across all blocks so far + row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum + row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block + m = m_p #save max + output_row_start_ptr = output_ptr + row_idx * output_row_stride #Loop 2 - for b in tl.range(0, n_cols, BLOCK_SIZE): - col_offsets = b + tl.arange(0, BLOCK_SIZE) + loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1 + for b in tl.range(0, loop_num): + col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets - mask = col_offsets < n_cols - row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block + row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block #subtract, exponentiate and divide by sum softmax_output = tl.exp(row_block - m) / row_sum #store output_ptrs = output_row_start_ptr + col_offsets - tl.store(output_ptrs, softmax_output, mask=mask) + tl.store(output_ptrs, softmax_output) + + col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask = col_offsets < n_cols + row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block + #subtract, exponentiate and divide by sum + softmax_output = tl.exp(row_block - m) / row_sum + #store + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) def softmax(x): @@ -96,7 +123,7 @@ def softmax(x): num_programs = n_rows grid = lambda meta: (num_programs, ) - softmax_kernel_online[grid]( + softmax_kernel[grid]( y, x, x.stride(0), From ba9a70c1feb52422e196e87222f591cb709c4257 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Wed, 25 Sep 2024 01:47:25 +0000 Subject: [PATCH 19/20] Use tl.exp2 --- python/perf-kernels/softmax.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py index 19764b2138e3..a549a87461ad 100644 --- a/python/perf-kernels/softmax.py +++ b/python/perf-kernels/softmax.py @@ -75,39 +75,40 @@ def softmax_kernel( row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block m_p = tl.max(row_block, axis=0) #find block max m_p = tl.maximum(m, m_p) #Find new max across all blocks so far - row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum - row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block + row_sum = row_sum * tl.exp2((m - m_p) * 1.44269504) #Adjust previous sum + row_sum += tl.sum(tl.exp2((row_block - m_p) * 1.44269504)) #Add to exponentiated sum of this block m = m_p #save max + #Last iteration with masked load/store col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block m_p = tl.max(row_block, axis=0) #find block max m_p = tl.maximum(m, m_p) #Find new max across all blocks so far - row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum - row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block + row_sum = row_sum * tl.exp2((m - m_p) * 1.44269504) #Adjust previous sum + row_sum += tl.sum(tl.exp2((row_block - m_p) * 1.44269504)) #Add to exponentiated sum of this block m = m_p #save max output_row_start_ptr = output_ptr + row_idx * output_row_stride #Loop 2 - loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1 for b in tl.range(0, loop_num): col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block #subtract, exponentiate and divide by sum - softmax_output = tl.exp(row_block - m) / row_sum + softmax_output = tl.exp2((row_block - m) * 1.44269504) / row_sum #store output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output) + #Last iteration with masked load/store col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block #subtract, exponentiate and divide by sum - softmax_output = tl.exp(row_block - m) / row_sum + softmax_output = tl.exp2((row_block - m) * 1.44269504) / row_sum #store output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output, mask=mask) From a542825f572f9c6b1b2e5989c0ab95a16ea4d010 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Wed, 25 Sep 2024 22:03:19 +0000 Subject: [PATCH 20/20] Use SW pipelining with loops --- python/perf-kernels/softmax.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py index a549a87461ad..49d09283b6a3 100644 --- a/python/perf-kernels/softmax.py +++ b/python/perf-kernels/softmax.py @@ -39,6 +39,15 @@ def get_hip_autotune_config(): triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=2), + triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=2), + triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=2), + triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=2), + triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=2), + triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=2), + triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=2), + triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=2), + triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=2), ] @@ -69,7 +78,7 @@ def softmax_kernel( m = -float('inf') #Initial value of max row_sum = 0.0 row_start_ptr = input_ptr + row_idx * input_row_stride - for b in tl.range(0, loop_num): + for b in tl.range(0, loop_num, num_stages=2): col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block @@ -92,7 +101,7 @@ def softmax_kernel( output_row_start_ptr = output_ptr + row_idx * output_row_stride #Loop 2 - for b in tl.range(0, loop_num): + for b in tl.range(0, loop_num, num_stages=2): col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block @@ -118,6 +127,7 @@ def softmax(x): n_rows, n_cols = x.shape MAX_FUSED_SIZE = 65536 // x.element_size() + #MAX_FUSED_SIZE = 16384 // x.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols)) y = torch.empty_like(x)