From cf7dc90ee095ec6813bd91bbec5fc669d8d302af Mon Sep 17 00:00:00 2001 From: Xiaozhe Yao Date: Tue, 16 Jul 2024 10:14:56 +0200 Subject: [PATCH] Feature/utils (#11) * migrate utils * minor fix * minor * minor * add utilities * format * play with tk --- .gitignore | 3 +- benchmarks/bench_bmm.py | 9 +- benchmarks/bench_matmul.py | 9 +- benchmarks/bench_sbmm.py | 14 +- docs/examples/04_tp.py | 52 ++ setup.py | 20 +- tests/ops/test_matmul.py | 1 - tests/ops/test_sbmm.py | 3 +- tests/ops/test_sbmm_tp.py | 29 +- tests/utils/test_sparsity.py | 0 triteia/csrc/flash_kittens/README.md | 3 + .../csrc/flash_kittens/common/base_ops.cuh | 301 ++++++++ .../csrc/flash_kittens/common/base_types.cuh | 216 ++++++ triteia/csrc/flash_kittens/common/common.cuh | 11 + .../flash_kittens/common/pyutils/__init__.py | 1 + .../common/pyutils/test_build_utils.py | 96 +++ .../common/pyutils/torch_helpers.cuh | 71 ++ triteia/csrc/flash_kittens/common/util.cuh | 229 ++++++ triteia/csrc/flash_kittens/kittens.cuh | 31 + .../csrc/flash_kittens/ops/group/group.cuh | 42 ++ .../flash_kittens/ops/group/memory/memory.cuh | 7 + .../group/memory/tile/global_to_register.cuh | 157 ++++ .../group/memory/tile/global_to_shared.cuh | 113 +++ .../group/memory/tile/shared_to_register.cuh | 107 +++ .../ops/group/memory/tile/tile.cuh | 8 + .../group/memory/vec/global_to_register.cuh | 43 ++ .../ops/group/memory/vec/global_to_shared.cuh | 68 ++ .../group/memory/vec/shared_to_register.cuh | 46 ++ .../ops/group/memory/vec/vec.cuh | 8 + .../flash_kittens/ops/group/shared/shared.cuh | 7 + .../ops/group/shared/tile/conversions.cuh | 27 + .../ops/group/shared/tile/maps.cuh | 433 +++++++++++ .../ops/group/shared/tile/reductions.cuh | 266 +++++++ .../ops/group/shared/tile/tile.cuh | 8 + .../ops/group/shared/vec/conversions.cuh | 27 + .../ops/group/shared/vec/maps.cuh | 237 ++++++ .../ops/group/shared/vec/vec.cuh | 9 + .../ops/group/wgmma/base/4x1.impl | 69 ++ .../ops/group/wgmma/base/4x16.impl | 189 +++++ .../ops/group/wgmma/base/4x2.impl | 77 ++ .../ops/group/wgmma/base/4x3.impl | 85 +++ .../ops/group/wgmma/base/4x4.impl | 93 +++ .../ops/group/wgmma/base/4x6.impl | 109 +++ .../ops/group/wgmma/base/4x8.impl | 125 ++++ .../ops/group/wgmma/base/base.cuh | 119 +++ .../flash_kittens/ops/group/wgmma/wgmma.cuh | 333 +++++++++ triteia/csrc/flash_kittens/ops/ops.cuh | 9 + .../flash_kittens/ops/warp/memory/memory.cuh | 10 + .../ops/warp/memory/tile/dsmem.cuh | 31 + .../warp/memory/tile/global_to_register.cuh | 182 +++++ .../ops/warp/memory/tile/global_to_shared.cuh | 160 ++++ .../warp/memory/tile/shared_to_register.cuh | 120 +++ .../ops/warp/memory/tile/tile.cuh | 15 + .../ops/warp/memory/tile/tma.cuh | 653 +++++++++++++++++ .../ops/warp/memory/util/dsmem.cuh | 136 ++++ .../ops/warp/memory/util/tma.cuh | 139 ++++ .../ops/warp/memory/util/util.cuh | 34 + .../ops/warp/memory/vec/dsmem.cuh | 31 + .../warp/memory/vec/global_to_register.cuh | 120 +++ .../ops/warp/memory/vec/global_to_shared.cuh | 52 ++ .../warp/memory/vec/shared_to_register.cuh | 128 ++++ .../flash_kittens/ops/warp/memory/vec/tma.cuh | 270 +++++++ .../flash_kittens/ops/warp/memory/vec/vec.cuh | 15 + .../ops/warp/register/register.cuh | 9 + .../ops/warp/register/tile/conversions.cuh | 323 +++++++++ .../ops/warp/register/tile/maps.cuh | 682 ++++++++++++++++++ .../ops/warp/register/tile/mma.cuh | 448 ++++++++++++ .../ops/warp/register/tile/reductions.cuh | 455 ++++++++++++ .../ops/warp/register/tile/tile.cuh | 11 + .../ops/warp/register/vec/conversions.cuh | 104 +++ .../ops/warp/register/vec/maps.cuh 
| 270 +++++++ .../ops/warp/register/vec/reductions.cuh | 180 +++++ .../ops/warp/register/vec/vec.cuh | 10 + .../flash_kittens/ops/warp/shared/shared.cuh | 9 + .../ops/warp/shared/tile/conversions.cuh | 60 ++ .../ops/warp/shared/tile/maps.cuh | 445 ++++++++++++ .../ops/warp/shared/tile/reductions.cuh | 277 +++++++ .../ops/warp/shared/tile/tile.cuh | 10 + .../ops/warp/shared/vec/conversions.cuh | 55 ++ .../ops/warp/shared/vec/maps.cuh | 250 +++++++ .../ops/warp/shared/vec/reductions.cuh | 159 ++++ .../flash_kittens/ops/warp/shared/vec/vec.cuh | 10 + triteia/csrc/flash_kittens/ops/warp/warp.cuh | 13 + .../flash_kittens/types/register/register.cuh | 11 + .../csrc/flash_kittens/types/register/rt.cuh | 190 +++++ .../flash_kittens/types/register/rt_base.cuh | 92 +++ .../types/register/rt_layout.cuh | 42 ++ .../csrc/flash_kittens/types/register/rv.cuh | 83 +++ .../flash_kittens/types/shared/shared.cuh | 10 + .../csrc/flash_kittens/types/shared/st.cuh | 212 ++++++ .../flash_kittens/types/shared/st_layout.cuh | 137 ++++ .../csrc/flash_kittens/types/shared/sv.cuh | 95 +++ triteia/csrc/flash_kittens/types/types.cuh | 51 ++ triteia/python/configs/gpus/specs.py | 2 +- triteia/python/nn/linear.py | 3 + triteia/python/ops/utils/generator.py | 74 ++ triteia/python/ops/utils/sparsity.py | 17 +- triteia/python/utils/benchmark.py | 60 +- triteia/python/utils/io.py | 5 +- triteia/python/utils/quant_utils.py | 55 +- triteia/tools/converters/convert_deltazip.py | 104 +-- triteia/tools/export_benchmark.py | 13 +- triteia/tools/verify_weights.py | 92 ++- 103 files changed, 10971 insertions(+), 173 deletions(-) create mode 100644 docs/examples/04_tp.py create mode 100644 tests/utils/test_sparsity.py create mode 100644 triteia/csrc/flash_kittens/README.md create mode 100644 triteia/csrc/flash_kittens/common/base_ops.cuh create mode 100644 triteia/csrc/flash_kittens/common/base_types.cuh create mode 100644 triteia/csrc/flash_kittens/common/common.cuh create mode 100644 triteia/csrc/flash_kittens/common/pyutils/__init__.py create mode 100644 triteia/csrc/flash_kittens/common/pyutils/test_build_utils.py create mode 100644 triteia/csrc/flash_kittens/common/pyutils/torch_helpers.cuh create mode 100644 triteia/csrc/flash_kittens/common/util.cuh create mode 100644 triteia/csrc/flash_kittens/kittens.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/group.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/memory.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_shared.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/tile/shared_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/tile/tile.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_shared.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/vec/shared_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/memory/vec/vec.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/shared/shared.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/shared/tile/conversions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/shared/tile/maps.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/shared/tile/reductions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/shared/tile/tile.cuh create mode 100644 
triteia/csrc/flash_kittens/ops/group/shared/vec/conversions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/shared/vec/maps.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/shared/vec/vec.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/4x1.impl create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/4x16.impl create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/4x2.impl create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/4x3.impl create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/4x4.impl create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/4x6.impl create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/4x8.impl create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/base/base.cuh create mode 100644 triteia/csrc/flash_kittens/ops/group/wgmma/wgmma.cuh create mode 100644 triteia/csrc/flash_kittens/ops/ops.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/memory.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/tile/dsmem.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_shared.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/tile/shared_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/tile/tile.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/tile/tma.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/util/dsmem.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/util/tma.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/util/util.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/vec/dsmem.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_shared.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/vec/shared_to_register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/vec/tma.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/memory/vec/vec.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/register.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/tile/conversions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/tile/maps.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/tile/mma.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/tile/reductions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/tile/tile.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/vec/conversions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/vec/maps.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/vec/reductions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/register/vec/vec.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/shared.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/tile/conversions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/tile/maps.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/tile/reductions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/tile/tile.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/vec/conversions.cuh create mode 100644 
triteia/csrc/flash_kittens/ops/warp/shared/vec/maps.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/vec/reductions.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/shared/vec/vec.cuh create mode 100644 triteia/csrc/flash_kittens/ops/warp/warp.cuh create mode 100644 triteia/csrc/flash_kittens/types/register/register.cuh create mode 100644 triteia/csrc/flash_kittens/types/register/rt.cuh create mode 100644 triteia/csrc/flash_kittens/types/register/rt_base.cuh create mode 100644 triteia/csrc/flash_kittens/types/register/rt_layout.cuh create mode 100644 triteia/csrc/flash_kittens/types/register/rv.cuh create mode 100644 triteia/csrc/flash_kittens/types/shared/shared.cuh create mode 100644 triteia/csrc/flash_kittens/types/shared/st.cuh create mode 100644 triteia/csrc/flash_kittens/types/shared/st_layout.cuh create mode 100644 triteia/csrc/flash_kittens/types/shared/sv.cuh create mode 100644 triteia/csrc/flash_kittens/types/types.cuh diff --git a/.gitignore b/.gitignore index 7ba70a2..4beac97 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json # templates .github/templates/* -.vscode/ \ No newline at end of file +.vscode/ +.local \ No newline at end of file diff --git a/benchmarks/bench_bmm.py b/benchmarks/bench_bmm.py index e8c696e..96ba99f 100644 --- a/benchmarks/bench_bmm.py +++ b/benchmarks/bench_bmm.py @@ -4,7 +4,11 @@ gen_batched_sparse_quant4_NT, bmm_4bit_2_4, ) -from triteia.python.utils import timing_function, print_results_table, export_benchmark_results +from triteia.python.utils import ( + timing_function, + print_results_table, + export_benchmark_results, +) from triteia.python.configs.models.llama import llama_shapes flops_func = lambda b, m, n, k: 2 * b * m * n * k @@ -65,9 +69,10 @@ def w4_2_4_native(qweight, x, meta, scale): print_results_table(f"bmm b={b},m={m},n={n},k={k}", results) return results + if __name__ == "__main__": results = [] results.append(benchmark(2, 256, 32, 256)) results.append(benchmark(8, 4096, 32, 4096)) results.append(benchmark(8, 8192, 8, 8192)) - export_benchmark_results(results, ".local/bmm_bench.json") \ No newline at end of file + export_benchmark_results(results, ".local/bmm_bench.json") diff --git a/benchmarks/bench_matmul.py b/benchmarks/bench_matmul.py index eb7c937..68a52e4 100644 --- a/benchmarks/bench_matmul.py +++ b/benchmarks/bench_matmul.py @@ -1,6 +1,10 @@ import torch from triteia.python.ops import matmul_4bit_2_4, gen_sparse_quant4_NT -from triteia.python.utils import timing_function, print_results_table, export_benchmark_results +from triteia.python.utils import ( + timing_function, + print_results_table, + export_benchmark_results, +) flops_func = lambda m, n, k: 2 * m * n * k @@ -41,6 +45,7 @@ def w4_2_4_func(qweight, x, meta, scale): print_results_table(f"matmul m={m},n={n},k={k}", results) return results + if __name__ == "__main__": results = [] results.append(benchmark(256, 32, 256)) @@ -48,4 +53,4 @@ def w4_2_4_func(qweight, x, meta, scale): results.append(benchmark(4096, 4096, 4096)) results.append(benchmark(8192, 8192, 8192)) results.append(benchmark(16384, 128, 16384)) - export_benchmark_results(results, ".local/matmul_bench.json") \ No newline at end of file + export_benchmark_results(results, ".local/matmul_bench.json") diff --git a/benchmarks/bench_sbmm.py b/benchmarks/bench_sbmm.py index 975ed08..ee8c1af 100644 --- a/benchmarks/bench_sbmm.py +++ b/benchmarks/bench_sbmm.py @@ -5,7 +5,11 @@ sbmm_4bit_2_4_native, sbmm_4bit_2_4_multilaunch, ) -from triteia.python.utils import 
timing_function, print_results_table, export_benchmark_results +from triteia.python.utils import ( + timing_function, + print_results_table, + export_benchmark_results, +) from triteia.python.configs.models.llama import llama_shapes from triteia.python.ops.utils.generator import generate_model_distribution from triteia.python.ops import gen_batched_sparse_quant4_NT @@ -112,7 +116,7 @@ def w4_2_4_multilaunch_func(qweight, x, meta, scale, indices): [2, 4, 8, 16, 32, 64], [2, 4, 8, 16, 32, 64, 128], ] - distributions = ['uniform', 'zipf:1.5'] + distributions = ["uniform", "zipf:1.5"] ms = [4096, 8192] ns = [4096, 8192] for distribution in distributions: @@ -125,6 +129,8 @@ def w4_2_4_multilaunch_func(qweight, x, meta, scale, indices): benchmark(distribution, nr[i], nm[i][j], m, n) ) except Exception as e: - print(f"Failed to benchmark sbmm nr={nr[i]},nm={nm[i][j]},m={m},n={n}") + print( + f"Failed to benchmark sbmm nr={nr[i]},nm={nm[i][j]},m={m},n={n}" + ) print(e) - export_benchmark_results(results, ".local/sbmm_bench.json") \ No newline at end of file + export_benchmark_results(results, ".local/sbmm_bench.json") diff --git a/docs/examples/04_tp.py b/docs/examples/04_tp.py new file mode 100644 index 0000000..5c357fb --- /dev/null +++ b/docs/examples/04_tp.py @@ -0,0 +1,52 @@ +import torch +from triteia.python.ops import gen_sparse_quant4_NT, matmul_4bit_2_4 +from triteia.python.ops.utils.generator import torch_weight_to_sparse_marlin + +dev = "cuda" +n = 1 +m = 12288 +k = 6144 +groupsize = -1 +tp_size = 8 + +x = torch.randn((n, k), dtype=torch.float16, device=dev) +weight_ref, qweight, scale, meta = gen_sparse_quant4_NT( + m, k, groupsize=groupsize, device=dev +) +print( + f"weight_ref: {weight_ref.shape}, qweight: {qweight.shape}, scale: {scale.shape}, meta: {meta.shape}" +) +fp16_output = torch.matmul(x, weight_ref) +qs_output = matmul_4bit_2_4(qweight, x, meta, scale) + +qweights_by_tp, scales_by_tp, metas_by_tp = torch_weight_to_sparse_marlin( + weight_ref, scale, tp_size=tp_size, chunk_by="column" +) +partial_outputs = [] +partial_fp16_outputs = [] +for i in range(tp_size): + tp_weight = weight_ref[:, i * m // tp_size : (i + 1) * m // tp_size].contiguous() + tp_scales = scale[:, i * m // tp_size : (i + 1) * m // tp_size].contiguous() + partial_output = matmul_4bit_2_4( + qweights_by_tp[i], x, metas_by_tp[i], scales_by_tp[i] + ) + + partial_outputs.append(partial_output) + partial_fp16_output = torch.matmul(x, tp_weight) + partial_fp16_outputs.append(partial_fp16_output) + +torch.manual_seed(0) +torch.backends.cudnn.deterministic = True +tp_output = torch.cat(partial_outputs, dim=1) +fp16_merged_output = torch.cat(partial_fp16_outputs, dim=1) + +print(f"max diff (quant): {torch.max(torch.abs(fp16_output - qs_output))}") +print( + f"max relative diff (tp): {torch.max(torch.abs(tp_output - fp16_output)/torch.mean(torch.abs(fp16_output)))}" +) +print(tp_output - qs_output) +print(fp16_output - fp16_merged_output) +print( + f"mean diff (fp16): {torch.mean(torch.abs(fp16_output - fp16_merged_output)/torch.mean(torch.abs(fp16_output)))}" +) +# print(f"\n\n{tp_output}\n{qs_output}") diff --git a/setup.py b/setup.py index d698558..0f2a17d 100644 --- a/setup.py +++ b/setup.py @@ -5,15 +5,22 @@ from setuptools import find_packages, setup from torch.utils import cpp_extension + def get_compute_capability(): try: - compute_cap = os.popen("nvidia-smi --query-gpu=compute_cap --format=csv,noheader").read().strip().split("\n")[0] + compute_cap = ( + os.popen("nvidia-smi --query-gpu=compute_cap
--format=csv,noheader") + .read() + .strip() + .split("\n")[0] + ) major, minor = compute_cap.split(".") return f"{major}{minor}" except Exception as e: print(f"Failed to detect compute capability: {e}") return None + def read(*paths, **kwargs): """Read the contents of a text file safely. >>> read("triteia", "VERSION") @@ -36,7 +43,8 @@ def read_requirements(path): for line in read(path).split("\n") if not line.startswith(('"', "#", "-", "git+")) ] - + + compute_cap = get_compute_capability() if compute_cap is None: raise ValueError("Failed to detect compute capability") @@ -63,7 +71,13 @@ def read_requirements(path): ], dlink=True, extra_compile_args={ - "nvcc": ["-O3", f"-arch=sm_{compute_cap}", "--ptxas-options=-v", "-dc", "-lineinfo"] + "nvcc": [ + "-O3", + f"-arch=sm_{compute_cap}", + "--ptxas-options=-v", + "-dc", + "-lineinfo", + ] }, extra_link_args=["-lcudadevrt", "-lcudart"], ), diff --git a/tests/ops/test_matmul.py b/tests/ops/test_matmul.py index 67d85e7..acd952f 100644 --- a/tests/ops/test_matmul.py +++ b/tests/ops/test_matmul.py @@ -28,7 +28,6 @@ def test_tiny(self): # self.run_problem(256, 16, 256, groupsize=-1) # self.run_problem(256, 16, 512, groupsize=-1) self.run_problem(5504, 5504, 5504, groupsize=-1) - # def test_llama(self): # bsz = 16 diff --git a/tests/ops/test_sbmm.py b/tests/ops/test_sbmm.py index dc45027..5afdeca 100644 --- a/tests/ops/test_sbmm.py +++ b/tests/ops/test_sbmm.py @@ -67,7 +67,7 @@ def run_problem( def test_tiny(self): self.run_problem("uniform", 10, 5, 256, 256) self.run_problem("zipf:1.5", 128, 2, 4096, 12288) - + # def test_llama(self): # nrs = [16, 32, 64, 128, 256] # nms = [[2,4,8,16], [2,4,8,16,32], [2,4,8,16,32,64], [2,4,8,16,32,64,128], [2,4,8,16,32,64,128,256]] @@ -79,5 +79,6 @@ def test_tiny(self): # for distribution in distributions: # self.run_problem(distribution, nr, nm, layer[0], layer[1]) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops/test_sbmm_tp.py b/tests/ops/test_sbmm_tp.py index bd9061e..8439ea8 100644 --- a/tests/ops/test_sbmm_tp.py +++ b/tests/ops/test_sbmm_tp.py @@ -42,25 +42,34 @@ def run_problem_column( qweights.append(qweight) scales.append(scale) metas.append(meta) - - fp16_partial_output = sbmm_16bit_forloop(weight_ref, x, indices, base_weight=None) + + fp16_partial_output = sbmm_16bit_forloop( + weight_ref, x, indices, base_weight=None + ) native_partial_output = sbmm_4bit_2_4_native( qweight, x, meta, scale, indices, base_weight=None ) ref_fp16_outputs.append(fp16_partial_output) outputs.append(native_partial_output) - + ref_fp16_final_outputs = torch.cat(ref_fp16_outputs, dim=1) final_outputs = torch.cat(outputs, dim=1) - + stacked_fp16_weights = torch.cat(ref_weights, dim=2) stacked_qweights = torch.cat(qweights, dim=2) stacked_scales = torch.cat(scales, dim=2) stacked_metas = torch.cat(metas, dim=1) - - stacked_fp16_output = sbmm_16bit_forloop(stacked_fp16_weights, x, indices, base_weight=None) + + stacked_fp16_output = sbmm_16bit_forloop( + stacked_fp16_weights, x, indices, base_weight=None + ) stacked_native_output = sbmm_4bit_2_4_native( - stacked_qweights, x, stacked_metas, stacked_scales, indices, base_weight=None + stacked_qweights, + x, + stacked_metas, + stacked_scales, + indices, + base_weight=None, ) self.assertLess( torch.mean(torch.abs(final_outputs - ref_fp16_final_outputs)) @@ -72,15 +81,15 @@ def run_problem_column( / torch.mean(torch.abs(ref_fp16_final_outputs)), 0.002, ) - + except torch.cuda.OutOfMemoryError as e: print(f"Out of memory, skipping nr={nr}, nm={nm}, m={m}, k={k}") 
finally: torch.cuda.empty_cache() def test_tiny(self): - self.run_problem_column("uniform", 10, 5, 256, 256, 2) - + self.run_problem_column("uniform", 10, 5, 256, 256, 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils/test_sparsity.py b/tests/utils/test_sparsity.py new file mode 100644 index 0000000..e69de29 diff --git a/triteia/csrc/flash_kittens/README.md b/triteia/csrc/flash_kittens/README.md new file mode 100644 index 0000000..719e9dc --- /dev/null +++ b/triteia/csrc/flash_kittens/README.md @@ -0,0 +1,3 @@ +# + +> This is a fork of [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/common/base_ops.cuh b/triteia/csrc/flash_kittens/common/base_ops.cuh new file mode 100644 index 0000000..4778670 --- /dev/null +++ b/triteia/csrc/flash_kittens/common/base_ops.cuh @@ -0,0 +1,301 @@ +/** + * @file + * @brief Basic operations on generic types. + */ + +#pragma once + +#include +#include +#include "base_types.cuh" + +namespace kittens { + +/** + * @namespace base_ops + * + * @brief A namespace for operations on basic data types. + */ +namespace base_ops { + +/* ---------- CONST OPS ---------- */ + +/** + * @brief Represents the zero constant operation. + * + * This operation returns the zero value of the specified type. + * + * @tparam T The data type for which to return the zero value. + * @return The zero value of type T. + */ +struct zero { + template __device__ static inline constexpr T op(args... _) { return base_types::constants::zero(); } +}; +/** + * @brief Represents the one constant operation. + * + * This operation returns the one value of the specified type. + * + * @tparam T The data type for which to return the one value. + * @return The one value of type T. + */ +struct one { + template __device__ static inline constexpr T op(args... _) { return base_types::constants::one(); } +}; +/** + * @brief Represents the positive infinity constant operation. + * + * This operation returns the positive infinity value of the specified type. + * + * @tparam T The data type for which to return the positive infinity value. + * @return The positive infinity value of type T. + */ +struct pos_infty { + template __device__ static inline constexpr T op(args... _) { return base_types::constants::pos_infty(); } +}; +/** + * @brief Represents the negative infinity constant operation. + * + * This operation returns the negative infinity value of the specified type. + * + * @tparam T The data type for which to return the negative infinity value. + * @return The negative infinity value of type T. + */ +struct neg_infty { + template __device__ static inline constexpr T op(args... _) { return base_types::constants::neg_infty(); } +}; + + +/* ---------- UNARY OPS ---------- */ + +/** + * @brief Exponential function operation. + * + * This operation calculates the exponential of the input value. + * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The exponential of the input value. 
+ */ +struct exp { + template static __device__ inline T op(const T &x) { return exp(x); } +}; +template<> __device__ inline float exp::op (const float &x ) { return __expf(x); } +template<> __device__ inline float2 exp::op(const float2 &x) { return float2{__expf(x.x), __expf(x.y)}; } +template<> __device__ inline bf16 exp::op (const bf16 &x ) { return hexp(x); } +template<> __device__ inline bf16_2 exp::op(const bf16_2 &x) { return h2exp(x); } +/** + * @brief Natural log function operation. + * + * This operation calculates the natural logarithm of the input value. + * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The natural logarithm of the input value. + */ +struct log { + template static __device__ inline T op(const T &x) { return log(x); } +}; +template<> __device__ inline float log::op (const float &x ) { return __logf(x); } +template<> __device__ inline float2 log::op(const float2 &x) { return float2{__logf(x.x), __logf(x.y)}; } +template<> __device__ inline bf16 log::op (const bf16 &x ) { return hlog(x); } +template<> __device__ inline bf16_2 log::op(const bf16_2 &x) { return h2log(x); } +/** + * @brief Absolute value operation. + * + * This operation calculates the absolute value of the input. + * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The absolute value of the input. + */ +struct abs { + template static __device__ inline T op(const T &x) { return abs(x); } +}; +template<> __device__ inline float abs::op (const float &x ) { return fabsf(x); } +template<> __device__ inline float2 abs::op(const float2 &x) { return float2{fabsf(x.x), fabsf(x.y)}; } +template<> __device__ inline bf16 abs::op (const bf16 &x ) { return __habs(x); } +template<> __device__ inline bf16_2 abs::op(const bf16_2 &x) { return __habs2(x); } +/** + * @brief Rectified Linear Unit (ReLU) operation. + * + * This operation applies the ReLU function to the input, which is the + * maximum of zero and the input value. + * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The result of ReLU function applied to the input. + */ +struct relu { + template static __device__ inline T op(const T &x) { return max(x, base_types::constants::zero()); } +}; +template<> __device__ inline float relu::op (const float &x ) { return max(x, 0.f); } +template<> __device__ inline float2 relu::op(const float2 &x) { return float2{max(x.x, 0.f), max(x.y, 0.f)}; } +template<> __device__ inline bf16 relu::op (const bf16 &x ) { return __hmax(x, base_types::constants::zero()); } +template<> __device__ inline bf16_2 relu::op(const bf16_2 &x) { return __hmax2(x, base_types::constants::zero()); } +/** + * @brief Copy operation. + * + * This operation returns the input value unchanged. + * + * @tparam T The data type of the input and output values. + * @param a[in] The input value. + * @return The same value as the input. + */ +struct copy { // for non-compile-time setters. + template static __device__ inline T op(const T &a) { return a; } +}; + + +/* ---------- BINARY OPS ---------- */ + +/** + * @brief Copy2 operation. + * + * This operation returns the second input value unchanged. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value (ignored). + * @param b[in] The second input value. + * @return The same value as the second input. 
+ */ +struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/ + template static __device__ inline T op(const T &a, const T &b) { return b; } +}; +/** + * @brief Sum operation. + * + * This operation calculates the sum of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The sum of the input values. + */ +struct sum { + template static __device__ inline T op(const T &a, const T &b) { return a+b; } +}; +template<> __device__ inline float2 sum::op(const float2 &a, const float2 &b) { return float2{a.x+b.x, a.y+b.y}; } +template<> __device__ inline bf16 sum::op (const bf16 &a, const bf16 &b) { return __hadd(a, b); } +template<> __device__ inline bf16_2 sum::op(const bf16_2 &a, const bf16_2 &b) { return __hadd2(a, b); } +/** + * @brief Subtraction operation. + * + * This operation calculates the difference between two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The difference between the input values. + */ +struct sub { + template static __device__ inline T op(const T &a, const T &b) { return a-b; } +}; +template<> __device__ inline float2 sub::op(const float2 &a, const float2 &b) { return float2{a.x-b.x, a.y-b.y}; } +template<> __device__ inline bf16 sub::op (const bf16 &a, const bf16 &b) { return __hsub(a, b); } +template<> __device__ inline bf16_2 sub::op(const bf16_2 &a, const bf16_2 &b) { return __hsub2(a, b); } +/** + * @brief Multiplication operation. + * + * This operation calculates the product of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The product of the input values. + */ +struct mul { + template static __device__ inline T op(const T &a, const T &b) { return a*b; } +}; +template<> __device__ inline float2 mul::op(const float2 &a, const float2 &b) { return float2{a.x*b.x, a.y*b.y}; } +template<> __device__ inline bf16 mul::op (const bf16 &a, const bf16 &b) { return __hmul(a, b); } +template<> __device__ inline bf16_2 mul::op(const bf16_2 &a, const bf16_2 &b) { return __hmul2(a, b); } +/** + * @brief Division operation. + * + * This operation calculates the quotient of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The quotient of the input values. + */ +struct div { + template static __device__ inline T op(const T &a, const T &b) { return a/b; } +}; +template<> __device__ inline float2 div::op(const float2 &a, const float2 &b) { return float2{a.x/b.x, a.y/b.y}; } +template<> __device__ inline bf16 div::op (const bf16 &a, const bf16 &b) { return __hdiv(a, b); } +template<> __device__ inline bf16_2 div::op(const bf16_2 &a, const bf16_2 &b) { return __h2div(a, b); } // this op is a special snowflake +/** + * @brief Maximum operation. + * + * This operation calculates the maximum of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The maximum of the input values. 
+ */ + struct max { + template static __device__ inline T op(const T &a, const T &b) { return ::max(a, b); } +}; +template<> __device__ inline float2 max::op(const float2 &a, const float2 &b) { return float2{::max(a.x, b.x), ::max(a.y, b.y)}; } +template<> __device__ inline bf16 max::op (const bf16 &a, const bf16 &b) { return __hmax(a, b); } +template<> __device__ inline bf16_2 max::op(const bf16_2 &a, const bf16_2 &b) { return __hmax2(a, b); } +/** + * @brief Minimum operation. + * + * This operation calculates the minimum of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The minimum of the input values. + */ +struct min { + template static __device__ inline T op(const T &a, const T &b) { return ::min(a, b); } +}; +template<> __device__ inline float2 min::op(const float2 &a, const float2 &b) { return float2{::min(a.x, b.x), ::min(a.y, b.y)}; } +template<> __device__ inline bf16 min::op (const bf16 &a, const bf16 &b) { return __hmin(a, b); } +template<> __device__ inline bf16_2 min::op(const bf16_2 &a, const bf16_2 &b) { return __hmin2(a, b); } + + +/* ---------- TERNARY OPS ---------- */ + +/** + * @brief Fused multiply-add operation A * B + C. + * + * This operation performs a fused multiply-add, computing (A * B) + C with only one rounding. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @param c[in] The third input value to be added. + * @return The result of the fused multiply-add operation. + */ +struct fma_AxBtC { + template static __device__ inline T op(const T &a, const T &b, const T &c) { + return sum::op(mul::op(a, b), c); + } +}; +/** + * @brief Fused multiply-add operation A * C + B. + * + * This operation performs a fused multiply-add, computing (A * C) + B with only one rounding. + * This is particularly useful for attention mechanisms in neural networks. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The third input value to be added. + * @param c[in] The second input value. + * @return The result of the fused multiply-add operation. + */ +struct fma_AxCtB { // this is the one needed for attention + template static __device__ inline T op(const T &a, const T &b, const T &c) { + return sum::op(mul::op(a, c), b); + } +}; + +} // namespace base_ops + +} // namespace kittens diff --git a/triteia/csrc/flash_kittens/common/base_types.cuh b/triteia/csrc/flash_kittens/common/base_types.cuh new file mode 100644 index 0000000..b7f84f4 --- /dev/null +++ b/triteia/csrc/flash_kittens/common/base_types.cuh @@ -0,0 +1,216 @@ +/** + * @file + * @brief Declarations, manipulations, and wrappers for basic types. + * + * This file is a bunch of utilities for going back and forth between different types. + * + * Many of them are for the compiler, so as to clean up the code. It unfortunately + * seems necessary when we have types we really care about that are less than word width. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace kittens { + +/** + * @brief Bfloat16 floating-point type. + */ +using bf16 = __nv_bfloat16; +/** + * @brief Half-precision floating-point type. + */ +using half = __half; +/** + * @brief Packed word of two bfloat16 floating-point values. + */ +using bf16_2 = __nv_bfloat162; +/** + * @brief Packed word of two half-precision floating-point values. 
+ */ +using half_2 = __half2; + +namespace ducks { +/** + * @namespace base_types + * + * @brief A namespace for concepts for basic data types. + */ +namespace base_types { + +template +concept T2 = std::is_same_v || std::is_same_v; // could add half_2 later if implemented. + +} // namespace base_types +} // namespace ducks + +/** + * @namespace base_types + * + * @brief A namespace for ThunderKittens basic data types. + */ +namespace base_types { + +/** + * @brief Provides compile-time constants for different types. + * + * @tparam T The type for which to provide constants. + */ +template struct constants { + /** + * @brief Zero + * @return Constexpr zero with type T + */ + static __device__ inline constexpr T zero() { return T{0}; } + /** + * @brief One + * @return Constexpr one with type T + */ + static __device__ inline constexpr T one() { return T{1}; } + /** + * @brief Positive infinity. Particularly useful for initializing before a min op. + * @return Constexpr positive infinity with type T + */ + static __device__ inline constexpr T pos_infty() { return T{INFINITY}; } // I'll find a better way at some point but this appears to work. + /** + * @brief Negative infinity. Particularly useful for initializing before a max op. + * @return Constexpr negative infinity with type T + */ + static __device__ inline constexpr T neg_infty() { return T{-INFINITY}; } +}; +template<> struct constants { + static __device__ inline constexpr float2 zero() { return float2{0.f, 0.f}; } + static __device__ inline constexpr float2 one() { return float2{1.f, 1.f}; } + static __device__ inline constexpr float2 pos_infty() { return float2{constants::pos_infty(), constants::pos_infty()}; } + static __device__ inline constexpr float2 neg_infty() { return float2{constants::neg_infty(), constants::neg_infty()}; } +}; +template<> struct constants { + static __device__ inline constexpr bf16 zero() { return std::bit_cast<__nv_bfloat16>(uint16_t(0x0000)); } // unfortunately __float2bf16_rn is not constexpr + static __device__ inline constexpr bf16 one() { return std::bit_cast<__nv_bfloat16>(uint16_t(0x3F80)); } + static __device__ inline constexpr bf16 pos_infty() { return std::bit_cast<__nv_bfloat16>(uint16_t(0x7F80)); } + static __device__ inline constexpr bf16 neg_infty() { return std::bit_cast<__nv_bfloat16>(uint16_t(0xFF80)); } +}; +template<> struct constants { + static __device__ inline constexpr bf16_2 zero() { return bf16_2{constants::zero(), constants::zero()}; } + static __device__ inline constexpr bf16_2 one() { return bf16_2{constants::one(), constants::one()}; } + static __device__ inline constexpr bf16_2 pos_infty() { return bf16_2{constants::pos_infty(), constants::pos_infty()}; } + static __device__ inline constexpr bf16_2 neg_infty() { return bf16_2{constants::neg_infty(), constants::neg_infty()}; } +}; +template<> struct constants { + static __device__ inline constexpr half zero() { return std::bit_cast<__half>(uint16_t(0x0000)); } + static __device__ inline constexpr half one() { return std::bit_cast<__half>(uint16_t(0x3C00)); } + static __device__ inline constexpr half pos_infty() { return std::bit_cast<__half>(uint16_t(0x7C00)); } + static __device__ inline constexpr half neg_infty() { return std::bit_cast<__half>(uint16_t(0xFC00)); } +}; +template<> struct constants { + static __device__ inline constexpr half_2 zero() { return half_2{constants::zero(), constants::zero()}; } + static __device__ inline constexpr half_2 one() { return half_2{constants::one(), constants::one()}; } + static 
__device__ inline constexpr half_2 pos_infty() { return half_2{constants::pos_infty(), constants::pos_infty()}; } + static __device__ inline constexpr half_2 neg_infty() { return half_2{constants::neg_infty(), constants::neg_infty()}; } +}; + +/** + * @brief Provides information about packing of elements for a given type. + * + * @tparam T The type for which to provide packing information. + */ +template struct packing { + /** + * @brief The number of elements packed together. + * + * @return constexpr int representing number of elements within the type. + */ + static __device__ inline constexpr int num() { return 1; } + /** + * @brief Packs a single T element twice (replicated) into its packed type. + * + * @param i[in] The element to pack. + * @return The packed type. + */ + static __device__ inline constexpr T pack(const bf16 &i); +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using packed_type = bf16_2; + static __device__ inline constexpr bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using packed_type = half_2; + static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using packed_type = float2; + static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = bf16; + static __device__ inline constexpr bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = half; + static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = float; + static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } +}; + +/** + * @brief Provides templated functionality to convert between different types. + * + * @tparam T The target type for conversion. + * @tparam U The source type for conversion. + */ +template struct convertor { + /** + * @brief Converts a value of type U to type T. + * + * @param u[in] The value of type U to convert. + * @return T The converted value of type T. 
+ */ + static __device__ inline T convert(const U & u) { + return (T)u; + } +}; +template<> struct convertor<float, bf16> { + static __device__ inline float convert(const bf16 & u) { + return __bfloat162float(u); + } +}; +template<> struct convertor<bf16, float> { + static __device__ inline bf16 convert(const float & u) { + return __float2bfloat16_rn(u); + } +}; +template<> struct convertor<float2, bf16_2> { + static __device__ inline float2 convert(const bf16_2 & u) { + return __bfloat1622float2(u); + } +}; +template<> struct convertor<bf16_2, float2> { + static __device__ inline bf16_2 convert(const float2 & u) { + return __float22bfloat162_rn(u); + } +}; + +} +} diff --git a/triteia/csrc/flash_kittens/common/common.cuh b/triteia/csrc/flash_kittens/common/common.cuh new file mode 100644 index 0000000..7a95a71 --- /dev/null +++ b/triteia/csrc/flash_kittens/common/common.cuh @@ -0,0 +1,11 @@ +/** + * @file + * @brief A collection of common resources on which ThunderKittens depends. + */ + + +#pragma once + +#include "util.cuh" +#include "base_types.cuh" +#include "base_ops.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/common/pyutils/__init__.py b/triteia/csrc/flash_kittens/common/pyutils/__init__.py new file mode 100644 index 0000000..c88a742 --- /dev/null +++ b/triteia/csrc/flash_kittens/common/pyutils/__init__.py @@ -0,0 +1 @@ +"""Just some build utils""" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/common/pyutils/test_build_utils.py b/triteia/csrc/flash_kittens/common/pyutils/test_build_utils.py new file mode 100644 index 0000000..37a95b3 --- /dev/null +++ b/triteia/csrc/flash_kittens/common/pyutils/test_build_utils.py @@ -0,0 +1,96 @@ +import torch +from torch.utils.cpp_extension import load +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +tile = 16 +################## +# +# Extension help +# +# These are the commands for the PyTorch JIT... +# https://pytorch.org/tutorials/advanced/cpp_extension.html +import os +project_root = os.getenv("THUNDERKITTENS_ROOT") +if project_root is None: + print("There is no project root set (env: THUNDERKITTENS_ROOT); did you run env.src?") + os._exit(-1) + +def _sources(name): return [f"{name}_frontend.cpp", f"{name}.cu"] +def jit_build(name, debug=False, gpu_type='4090'): + _cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--generate-line-info', '--restrict', + f"-I {project_root}"] + + if gpu_type == '4090': + _cuda_flags.append('-DKITTENS_4090') + _cuda_flags.append('-arch=sm_89') + elif gpu_type == 'H100': + _cuda_flags.append('-DKITTENS_HOPPER') + _cuda_flags.append('-arch=sm_90a') + elif gpu_type == 'A100': + _cuda_flags.append('-DKITTENS_A100') + _cuda_flags.append('-arch=sm_80') + + if(debug): _cuda_flags += ['-D__DEBUG_PRINT', '-g', '-G', '-D TORCH_USE_CUDA_DSA'] + return load(name=f"{name}", sources=_sources(name), + extra_cflags=[], + extra_cuda_cflags=_cuda_flags) + + +def cuda_extension(name, debug, gpu_type): + _cuda_flags = [ + '--use_fast_math', + '--generate-line-info', + '--restrict', '-std=c++20', + '--expt-relaxed-constexpr', + '--expt-extended-lambda', + '-Xcompiler=-fno-strict-aliasing', + '-MD', '-MT', '-MF', '-x', 'cu', '-lrt', '-lpthread', '-ldl', + '-lcuda', '-lcudadevrt', '-lcudart_static', '-lcublas', + f"-I {project_root}" + ] + + if gpu_type == '4090': + _cuda_flags.append('-DKITTENS_4090') + _cuda_flags.append('-arch=sm_89') + elif gpu_type == 'H100': + _cuda_flags.append('-DKITTENS_HOPPER') + _cuda_flags.append('-arch=sm_90a') + elif gpu_type == 'A100': + _cuda_flags.append('-DKITTENS_A100') + _cuda_flags.append('-arch=sm_80') + + if(debug): _cuda_flags += ['-D__DEBUG_PRINT', '-g', '-G'] + return CUDAExtension(f'{name}', + sources=_sources(name), + extra_compile_args={'cxx' : ['-std=c++20'], + 'nvcc' : ['-O3'] + _cuda_flags}, + libraries=['cuda']) + +def library_build(name, debug=False, gpu_type='4090'): + setup(name=f"{name}", + ext_modules=[cuda_extension(name, debug, gpu_type)], + cmdclass={'build_ext': BuildExtension}) + +##### +# Helpers +# +def __eq(str, x,y, tol=1e-5, debug=False): + err = torch.abs(x-y).max() + pass_str = "pass" if err < tol else "fail" + print(f"{str} : {pass_str} [err={err:0.5f}]") + if(debug and (err > tol)): + print(f"x\n{x}") + print(f"y\n{y}") + print(f"diff\n{x-y}") + + return err <= tol + +def _rtile(b,n,d,dt): return torch.randn(b,n,d,device='cuda', dtype=dt)/(n*d) +def _rhtile(b,h,n,d,dt): return torch.randn(b,h,n,d,device='cuda', dtype=dt)/(n*d) +def _rones(b,n,d,dt): return torch.ones(b,n,d,device='cuda', dtype=dt) + +def print_tiles(str, t): + for i in range(t.size(0)): + for j in range(t.size(1)//tile): + print(f"{str} TILE batch={i} tile={j}") + print(f"{t[i,j*tile:(j+1)*tile,:]}") \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/common/pyutils/torch_helpers.cuh b/triteia/csrc/flash_kittens/common/pyutils/torch_helpers.cuh new file mode 100644 index 0000000..e06c59f --- /dev/null +++ b/triteia/csrc/flash_kittens/common/pyutils/torch_helpers.cuh @@ -0,0 +1,71 @@ +#include <torch/extension.h> +#include <cuda_fp16.h> +#include <cuda_bf16.h> + +// using namespace nvcuda; + +// ******* +// ** BROADCAST API +// Constant broadcasts.
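+// Editorial sketch, not part of the vendored header: the device_cast
+// overloads below and the DISPATCH macro defined further down are meant to
+// be used together when wiring a kernel to a torch::Tensor of unknown dtype.
+// `my_kernel`, `grid`, `block`, and `n` here are hypothetical:
+//
+//   void launch(torch::Tensor t, int n) {
+//       DISPATCH(t, (my_kernel<H><<<grid, block>>>(
+//           reinterpret_cast<H*>(t.data_ptr<T>()), n)));
+//   }
+//
+// Note the parentheses around the second argument, as the comment on
+// DISPATCH advises, so the preprocessor treats the launch as one argument;
+// inside each case, H and T name the device-side and torch-side types.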
+__device__ __nv_bfloat16* device_cast(c10::BFloat16* x) { return reinterpret_cast<__nv_bfloat16*>(x);} +__device__ half* device_cast(at::Half *x) { return reinterpret_cast<half*>(x); } +__device__ float* device_cast(float* x) { return x;} +__device__ const __nv_bfloat16* device_cast(const c10::BFloat16* x) { return reinterpret_cast<const __nv_bfloat16*>(x);} +__device__ const half* device_cast(const at::Half *x) { return reinterpret_cast<const half*>(x); } +__device__ const float* device_cast(const float* x) { return x;} + +// This is a dispatch helper macro for us, to match the device and the pytorch style. +// It's modeled after the AT_DISPATCH macros. +// Note that when you pass in FUNC you usually write it with parens to help the preprocessor parse it. +// It will give you errors about more parameters than the two it was expecting, if not. +#define DISPATCH(t, FUNC)\ + switch (t.scalar_type()) {\ + case c10::ScalarType::BFloat16: {\ + using H = __nv_bfloat16;\ + using D = __nv_bfloat162;\ + using T = c10::BFloat16;\ + using ACCUM = wmma_accum;\ + FUNC;\ + }\ + break;\ + case c10::ScalarType::Half: {\ + using H = half;\ + using D = __half2;\ + using T = at::Half;\ + using ACCUM = wmma_accum;\ + FUNC;\ + }\ + break;\ + case c10::ScalarType::Float: {\ + using H = float;\ + using D = float2;\ + using T = float;\ + using ACCUM = wmma_accum_tf32;\ + FUNC;\ + }\ + break;\ + default:\ + TORCH_CHECK(false, "Unsupported type!");\ + } + +// copied from online tutorial. +#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__) +template<typename T> +void check(T err, char const* const func, char const* const file, + int const line) +{ + if (err != cudaSuccess) + { + std::cerr << "CUDA Runtime Error at: " << file << ":" << line + << std::endl; + std::cerr << cudaGetErrorString(err) << " " << func << std::endl; + //std::exit(EXIT_FAILURE); + } +} + +bool is_tile(torch::Tensor t) {return t.size(0) == kittens::TILE_DIM && t.size(1) == kittens::TILE_DIM;} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_TILE(x) CHECK_INPUT(x); TORCH_CHECK(is_tile(x), #x " must be a 16x16 tile") \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/common/util.cuh b/triteia/csrc/flash_kittens/common/util.cuh new file mode 100644 index 0000000..ef8fde7 --- /dev/null +++ b/triteia/csrc/flash_kittens/common/util.cuh @@ -0,0 +1,229 @@ +/** + * @file + * @brief General utilities for ThunderKittens. + */ + +#pragma once + +#include +#include +#include +#include + +/** + * @namespace kittens + * + * @brief The main namespace of ThunderKittens. + */ +namespace kittens { + +/* ---------- GENERAL CONSTANTS FOR KITTENS ---------- */ + +/** + * @brief Tile dimension constant. + */ +constexpr int TILE_DIM{16}; +/** + * @brief Tile num elements constant calculated as TILE_DIM squared. + */ +constexpr int TILE_ELEMENTS{TILE_DIM*TILE_DIM}; +/** + * @brief Constant representing number of threads in a warp. + */ +constexpr int WARP_THREADS{32}; +/** + * @brief Constant representing number of threads in a warpgroup of four warps. + */ +constexpr int WARPGROUP_THREADS{128}; +/** + * @brief Constant representing number of warps in a warpgroup of four warps. + */ +constexpr int WARPGROUP_WARPS{4}; +/** + * @brief Get the warp ID of the current thread. + * @return The warp ID.
+ */ +__device__ __forceinline__ int warpid() { return threadIdx.x >> 5; } +/** + * @brief Get the warpgroup ID of the current thread. + * @return The warpgroup ID. + */ +__device__ __forceinline__ int warpgroupid() { return threadIdx.x >> 7; } +/** + * @brief Get the lane ID of the current thread within its warp. + * @return The lane ID. + */ +__device__ __forceinline__ int laneid() { return threadIdx.x & 0x1f; } + +#ifdef KITTENS_HOPPER +constexpr int MAX_SHARED_MEMORY = 227000; +#elif KITTENS_A100 +constexpr int MAX_SHARED_MEMORY = 164000; +#elif KITTENS_4090 +constexpr int MAX_SHARED_MEMORY = 100000; +#endif + +/* ---------- TYPE HELPERS ---------- */ + +/** + * @namespace ducks + * + * @brief ThunderKittens' namespace for template metaprogramming. + * + * This includes primarily dummy types and concept wrappers, along + * with a few additional utilities. + */ +namespace ducks { + +/** + * @brief A type representing an empty default for a template. + */ +struct default_type {}; + +// This macro can't be done as a template, so it doesn't really have a location in kittens. +#define typeof(A) typename std::remove_const<typename std::remove_reference<decltype(A)>::type>::type + +} + +/* ---------- SHUFFLE UTILS ---------- */ + +/** + * @brief Mask constant for all active threads in a warp. + */ +static constexpr uint32_t MASK_ALL = 0xFFFFFFFF; + +/** + * @brief Perform a shuffle down operation on a packed type synchronously across a warp. + * @tparam T The type of the value to be shuffled. + * @param mask[in] The mask of active threads. + * @param f[in] The value to be shuffled. + * @param delta[in] The number of positions to shuffle down. + * @return The result of the shuffle operation. + */ +template<typename T> +__device__ static inline T packed_shfl_down_sync(uint32_t mask, const T &f, int delta) { + return __shfl_down_sync(mask, f, delta); +} +template<> +__device__ inline float2 packed_shfl_down_sync(uint32_t mask, const float2 &f, int delta) { + float2 r; + r.x = __shfl_down_sync(mask, f.x, delta); + r.y = __shfl_down_sync(mask, f.y, delta); + return r; +} +/** + * @brief Perform a packed shuffle operation synchronously across a warp. + * @tparam T The type of the value to be shuffled. + * @param mask[in] The mask of active threads. + * @param f[in] The value to be shuffled. + * @param src[in] The source lane from which to shuffle. + * @return The result of the shuffle operation. + */ +template<typename T> +__device__ static inline T packed_shfl_sync(uint32_t mask, const T &f, int src) { + return __shfl_sync(mask, f, src); +} +template<> +__device__ inline float2 packed_shfl_sync(uint32_t mask, const float2 &f, int src) { + float2 r; + r.x = __shfl_sync(mask, f.x, src); + r.y = __shfl_sync(mask, f.y, src); + return r; +} + +/* ---------- SHARED MEMORY UTILS ---------- */ + +// Joyously stolen from https://github.com/NVIDIA/cutlass/blob/5c447dd84f8ae0e1d48ff9a2eae26ce8c4958101/include/cute/container/alignment.hpp#L51 +#if defined(__CUDACC__) +#define KITTENS_ALIGN_AS(n) __align__(n) +#else +#define KITTENS_ALIGN_AS(n) alignas(n) +#endif + +#ifdef KITTENS_HOPPER +#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(128) +#else +#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(16) +#endif + +/** + * @brief Dummy structure for alignment purposes. Needed for WGMMA and TMA calls. + */ +struct KITTENS_DEFAULT_ALIGN alignment_dummy { int dummy; }; +/** + * @brief Very simple allocator for dynamic shared memory. Advances pointer and tracks alignments. + * @tparam default_alignment The default alignment this allocator will enforce.
If <=0 (default -1) it will not align. + */ +template<int default_alignment=-1> // +struct shared_allocator { + int *ptr; + + private: + // Recursive template to generate N-dimensional array type + template<typename A, size_t... dims> + struct variadic_array; + template<typename A, size_t first_dim, size_t... rest_dims> + struct variadic_array<A, first_dim, rest_dims...> { + using type = typename variadic_array<A, rest_dims...>::type[first_dim]; + }; + template<typename A> + struct variadic_array<A> { + using type = A; + }; + template<typename A, size_t... dims> using variadic_array_t = typename variadic_array<A, dims...>::type; + + template<int alignment> + __device__ inline void align_ptr() { + if constexpr (alignment > 0) { + uint64_t p = reinterpret_cast<uint64_t>(ptr); + ptr = (int*)(p + (alignment-(p%alignment))); + } + } + + public: + /** + * @brief Construct a new shared allocator using a pointer to extern shared memory. + * @param[in] _ptr Pointer to the start of the extern shared memory. + */ + __device__ shared_allocator(int *_ptr): ptr(_ptr) {} + /** + * @brief Allocate shared memory for a single instance or N-dimensional array of type A. + * @tparam A The type of the object to allocate. + * @tparam dims... A list of dimensions for the N-dimensional array. + * @return Reference to the allocated object. + */ + template<typename A, size_t... dims> + __device__ inline variadic_array_t<A, dims...>& allocate() { + align_ptr<default_alignment>(); + using at = variadic_array_t<A, dims...>; + at*p = reinterpret_cast<at*>(ptr); + ptr += sizeof(at)/sizeof(int); + return *p; + } + /** + * @brief Allocate shared memory for a single instance or N-dimensional array of type A. + * @tparam alignment An alignment to enforce for this particular object. + * @tparam A The type of the object to allocate. + * @tparam dims... A list of dimensions for the N-dimensional array. + * @return Reference to the allocated object. + */ + template<int alignment, typename A, size_t... dims> + __device__ inline variadic_array_t<A, dims...>& allocate() { + align_ptr<alignment>(); + using at = variadic_array_t<A, dims...>; + at*p = reinterpret_cast<at*>(ptr); + ptr += sizeof(at)/sizeof(int); + return *p; + } +}; +#ifdef KITTENS_HOPPER +/** + * @brief A wrapper for an allocator that enforces sufficient alignment to be used for TMA loads and stores. + */ +using tma_allocator = shared_allocator<128>; +using tma_swizzle_allocator = shared_allocator<1024>; // swizzled TMA modes require up to 1024 byte alignments :/ +#endif + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/kittens.cuh b/triteia/csrc/flash_kittens/kittens.cuh new file mode 100644 index 0000000..186bd71 --- /dev/null +++ b/triteia/csrc/flash_kittens/kittens.cuh @@ -0,0 +1,31 @@ +/** + * @file + * @brief The master header file of ThunderKittens. This file includes everything you need! + */ + +#pragma once + +#include "common/common.cuh" +#include "types/types.cuh" +#include "ops/ops.cuh" + + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// Lifting a few really commonly used parts of the hierarchy +// up to the main namespace to make user code more concise. + +namespace kittens { + +using row_l = ducks::rt_layout::row; +using col_l = ducks::rt_layout::col; + +using naive_l = ducks::st_layout::naive; +using swizzle_l = ducks::st_layout::swizzle; +using wgmma_swizzle_l = ducks::st_layout::wgmma_swizzle; +using wgmma_interleave_l = ducks::st_layout::wgmma_interleave; + +using warpgroup = group<4>; // special scope commonly used by SM_90 and later.
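+// Editorial sketch, not part of the vendored headers: kernels typically carve
+// ThunderKittens objects out of dynamic shared memory with the allocator from
+// common/util.cuh. Assuming a shared-tile alias like st_bf_1x1 from
+// types/shared/st.cuh:
+//
+//   extern __shared__ kittens::alignment_dummy __shm[];
+//   kittens::shared_allocator<16> al((int*)&__shm[0]);
+//   auto &tile = al.allocate<kittens::st_bf_1x1>();
+//
+// On Hopper, tma_swizzle_allocator enforces the larger alignments that
+// swizzled TMA transfers require.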
+ +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/group.cuh b/triteia/csrc/flash_kittens/ops/group/group.cuh new file mode 100644 index 0000000..7a602ad --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/group.cuh @@ -0,0 +1,42 @@ +/** + * @file + * @brief An aggregate header of all group (multi-warp) operations defined by ThunderKittens + */ + +#pragma once + +#include + +#include "../../common/common.cuh" +#include "../../types/types.cuh" +#include "../warp/warp.cuh" // several group memory ops rely on underlying warp-scope ops + +// A "warpgroup" is a special group of 4 consecutive warps defined by NVIDIA for certain SM_90+ operations. +#define KITTENS_CHECK_WARPGROUP static_assert(N_WARPS==4, "PTX warpgroup (N_WARPS=4) function called from a non-warpgroup group."); + +// WGMMA relies on some template structures that cannot be specialized within the group struct, so we declare them in advance. +#ifdef KITTENS_HOPPER +#include "wgmma/base/base.cuh" +#endif + +namespace kittens { +/* +This is meant to be used with a `using group_N = kittens::group<N_WARPS>;` at the start of every kernel. +*/ +template<int N_WARPS> +struct group { +static constexpr int GROUP_THREADS = N_WARPS * kittens::WARP_THREADS; // This alias produces nice parallelism. +__device__ static inline int laneid() { return threadIdx.x % GROUP_THREADS; } +__device__ static inline int warpid() { return laneid() / kittens::WARP_THREADS; } +__device__ static inline int groupid() { return threadIdx.x / GROUP_THREADS; } + +#include "memory/memory.cuh" +#include "shared/shared.cuh" + +#ifdef KITTENS_HOPPER +#include "wgmma/wgmma.cuh" +#endif + +}; + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/memory/memory.cuh b/triteia/csrc/flash_kittens/ops/group/memory/memory.cuh new file mode 100644 index 0000000..0072d64 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/memory/memory.cuh @@ -0,0 +1,7 @@ +/** + * @file + * @brief An aggregate header of collaborative group memory movement operations + */ + +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_register.cuh b/triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_register.cuh new file mode 100644 index 0000000..23cbf7b --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_register.cuh @@ -0,0 +1,157 @@ +/** + * @file + * @brief Functions for a group to collaboratively transfer data directly between global memory and registers and back. + */ + +/** + * @brief Collaboratively loads data from a source array into row-major layout tiles. + * + * @tparam RT The row-major layout tile type. + * @tparam U The data type of the source array. + * @param dst[out] The destination tile to load data into. + * @param src[in] The source array to load data from. + * @param row_stride[in] The stride in elements between rows in the source array.
+ */ +template +__device__ inline static void load(RT &dst, const U *src, const int row_stride) { + using T2 = RT::dtype; + using U2 = base_types::packing::packed_type; + int warp_laneid = threadIdx.x % 32; + const int row_offset = dst.rows*warpid(); + #pragma unroll + for(int i = 0; i < dst.height; i++) { + int row = row_offset + i*dst.tile_size + (warp_laneid / 4); + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + 2*(warp_laneid % 4); + dst.tiles[i][j].data[0] = base_types::convertor::convert(*(U2*)(&src[(row+0)*row_stride + (col+0)])); + dst.tiles[i][j].data[2] = base_types::convertor::convert(*(U2*)(&src[(row+0)*row_stride + (col+8)])); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + 2*(warp_laneid % 4); + dst.tiles[i][j].data[1] = base_types::convertor::convert(*(U2*)(&src[(row+8)*row_stride + (col+0)])); + dst.tiles[i][j].data[3] = base_types::convertor::convert(*(U2*)(&src[(row+8)*row_stride + (col+8)])); + } + } +} +/** + * @brief Collaboratively loads data from a source array into column-major layout tiles. + * + * @tparam RT The column-major layout tile type. + * @tparam U The data type of the source array. + * @param dst[out] The destination tile to load data into. + * @param src[in] The source array to load data from. + * @param row_stride[in] The stride in elements between rows in the source array. + */ +template +__device__ inline static void load(RT &dst, const U *src, const int row_stride) { + using T = base_types::packing::unpacked_type; + int warp_laneid = threadIdx.x % 32; + const int row_offset = dst.rows*warpid(); + #pragma unroll + for(int i = 0; i < dst.height; i++) { + int row = row_offset + i*dst.tile_size + 2*(warp_laneid % 4); + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (warp_laneid / 4); + dst.tiles[i][j].data[0].x = base_types::convertor::convert(src[(row+0)*row_stride + (col+0)]); + dst.tiles[i][j].data[1].x = base_types::convertor::convert(src[(row+0)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (warp_laneid / 4); + dst.tiles[i][j].data[0].y = base_types::convertor::convert(src[(row+1)*row_stride + (col+0)]); + dst.tiles[i][j].data[1].y = base_types::convertor::convert(src[(row+1)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (warp_laneid / 4); + dst.tiles[i][j].data[2].x = base_types::convertor::convert(src[(row+8)*row_stride + (col+0)]); + dst.tiles[i][j].data[3].x = base_types::convertor::convert(src[(row+8)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (warp_laneid / 4); + dst.tiles[i][j].data[2].y = base_types::convertor::convert(src[(row+9)*row_stride + (col+0)]); + dst.tiles[i][j].data[3].y = base_types::convertor::convert(src[(row+9)*row_stride + (col+8)]); + } + } +} + + +/** + * @brief Collaboratively stores data from register tiles to a destination array in global memory with a row-major layout. + * + * @tparam RT The register tile type with a row-major layout. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register tile to store data from. + * @param row_stride[in] The stride in elements between rows in the destination array. 
+ */ +template +__device__ inline static void store(U *dst, const RT &src, const int row_stride) { + using T2 = RT::dtype; + using U2 = base_types::packing::packed_type; + int warp_laneid = threadIdx.x % 32; + const int row_offset = src.rows*warpid(); + #pragma unroll + for(int i = 0; i < src.height; i++) { + int row = row_offset + i*src.tile_size + (warp_laneid / 4); + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + 2*(warp_laneid % 4); + *(U2*)(&dst[(row+0)*row_stride + (col+0)]) = base_types::convertor::convert(src.tiles[i][j].data[0]); + *(U2*)(&dst[(row+0)*row_stride + (col+8)]) = base_types::convertor::convert(src.tiles[i][j].data[2]); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + 2*(warp_laneid % 4); + *(U2*)(&dst[(row+8)*row_stride + (col+0)]) = base_types::convertor::convert(src.tiles[i][j].data[1]); + *(U2*)(&dst[(row+8)*row_stride + (col+8)]) = base_types::convertor::convert(src.tiles[i][j].data[3]); + } + } +} +/** + * @brief Collaboratively stores data from register tiles to a destination array in global memory with a column-major layout. + * + * @tparam RT The register tile type with a column-major layout. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register tile to store data from. + * @param row_stride[in] The stride in elements between rows in the destination array. + */ +template +__device__ inline static void store(U *dst, const RT &src, const int row_stride) { + using T = base_types::packing::unpacked_type; + int warp_laneid = threadIdx.x % 32; + const int row_offset = src.rows*warpid(); + #pragma unroll + for(int i = 0; i < src.height; i++) { + int row = row_offset + i*src.tile_size + 2*(warp_laneid % 4); + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (warp_laneid / 4); + dst[(row+0)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[0].x); + dst[(row+0)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[1].x); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (warp_laneid / 4); + dst[(row+1)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[0].y); + dst[(row+1)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[1].y); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (warp_laneid / 4); + dst[(row+8)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[2].x); + dst[(row+8)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[3].x); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (warp_laneid / 4); + dst[(row+9)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[2].y); + dst[(row+9)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[3].y); + } + } +} diff --git a/triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_shared.cuh b/triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_shared.cuh new file mode 100644 index 0000000..1ecddf3 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/memory/tile/global_to_shared.cuh @@ -0,0 +1,113 @@ +/** + * @file + * @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory. 
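[Editor's note] The register-tile loads and stores above hard-code the lane-to-fragment mapping of NVIDIA's 16x16 mma tiles: lane `l` of each warp owns 2-element fragments at rows `l/4` and `l/4 + 8`, columns `2*(l%4)` and `2*(l%4) + 8`. A small host-side mirror of that arithmetic for checking, assuming `tile_size == 16`:

    #include <cstdio>
    // Prints the fragment origins owned by each lane, in the order
    // data[0], data[2], data[1], data[3] used by the row-major load above.
    int main() {
        for (int lane = 0; lane < 32; lane++) {
            int row = lane / 4, col = 2 * (lane % 4);
            std::printf("lane %2d: (%d,%d) (%d,%d) (%d,%d) (%d,%d)\n",
                        lane, row, col, row, col + 8, row + 8, col, row + 8, col + 8);
        }
    }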
+ */
+
+template<ducks::st::all ST>
+__device__ static inline void load(ST &dst, const typename ST::dtype *src, const int row_stride) {
+    // each thread needs to do 1 call per width*height / N_WARPS
+    // attempting to improve striping into dram
+    // each lane of the warp should store sequentially into dram
+
+    int laneid = threadIdx.x % GROUP_THREADS;
+
+    // we can handle this many rows each time we run a memcpy_async
+    int elem_per_memcpy = sizeof(float4)/sizeof(bf16);
+    int memcpy_per_row = dst.cols / elem_per_memcpy;
+    int total_calls = (dst.height * dst.width + (N_WARPS-1)) / N_WARPS; // round up
+
+    #pragma unroll
+    for(int i = 0; i < total_calls; i++) {
+
+        int idx = i * GROUP_THREADS + laneid;
+
+        int row = idx / memcpy_per_row;
+        int col = (idx*elem_per_memcpy) % dst.cols;
+
+        if (i<total_calls-1 || row<dst.rows)
+            *(float4*)(&dst[{row, col}]) = *(float4*)(&src[row*row_stride + col]);
+    }
+}
+template<ducks::st::all ST>
+__device__ static inline void store(typename ST::dtype *dst, const ST &src, const int row_stride) {
+
+    int laneid = threadIdx.x % GROUP_THREADS;
+
+    // we can handle this many rows each time we run a memcpy_async
+    int elem_per_memcpy = sizeof(float4)/sizeof(bf16);
+    int memcpy_per_row = src.cols / elem_per_memcpy;
+    int total_calls = (src.height * src.width + (N_WARPS-1)) / N_WARPS; // round up
+
+    #pragma unroll
+    for(int i = 0; i < total_calls; i++) {
+
+        int idx = i * GROUP_THREADS + laneid;
+
+        int row = idx / memcpy_per_row;
+        int col = (idx*elem_per_memcpy) % src.cols;
+
+        if (i<total_calls-1 || row<src.rows)
+            *(float4*)(&dst[row*row_stride + col]) = *(float4*)(&src[{row, col}]);
+    }
+}
+template<ducks::st::all ST>
+__device__ static inline void load_async(ST &dst, const typename ST::dtype *src, const int row_stride, cuda::barrier<cuda::thread_scope_block> &barrier) {
+    // each thread needs to do 1 call per width*height / N_WARPS
+    // attempting to improve striping into dram
+    // each lane of the warp should store sequentially into dram
+
+    int laneid = threadIdx.x % GROUP_THREADS;
+
+    // we can handle this many rows each time we run a memcpy_async
+    int elem_per_memcpy = sizeof(float4)/sizeof(bf16);
+    int memcpy_per_row = dst.cols / elem_per_memcpy;
+    int total_calls = (dst.height * dst.width + (N_WARPS-1)) / N_WARPS; // round up
+
+    #pragma unroll
+    for(int i = 0; i < total_calls; i++) {
+
+        int idx = i * GROUP_THREADS + laneid;
+
+        int row = idx / memcpy_per_row;
+        int col = (idx*elem_per_memcpy) % dst.cols;
+
+        if (i<total_calls-1 || row<dst.rows)
+            cuda::memcpy_async(
+                (void*)(&dst[{row, col}]),
+                (void*)(&src[row*row_stride + col]),
+                cuda::aligned_size_t<16>(sizeof(float4)),
+                barrier
+            );
+    }
+}
+template<ducks::st::all ST>
+__device__ static inline void store_async(typename ST::dtype *dst, const ST &src, const int row_stride, cuda::barrier<cuda::thread_scope_block> &barrier) {
+    // each thread needs to do 1 call per width*height / N_WARPS
+    // attempting to improve striping into dram
+    // each lane of the warp should store sequentially into dram
+
+    int laneid = threadIdx.x % GROUP_THREADS;
+
+    // we can handle this many rows each time we run a memcpy_async
+    int elem_per_memcpy = sizeof(float4)/sizeof(bf16);
+    int memcpy_per_row = src.cols / elem_per_memcpy;
+    int total_calls = (src.height * src.width + (N_WARPS-1)) / N_WARPS; // round up
+
+    #pragma unroll
+    for(int i = 0; i < total_calls; i++) {
+
+        int idx = i * GROUP_THREADS + laneid;
+
+        int row = idx / memcpy_per_row;
+        int col = (idx*elem_per_memcpy) % src.cols;
+
+        if (i<total_calls-1 || row<src.rows)
+            cuda::memcpy_async(
+                (void*)(&dst[row*row_stride + col]),
+                (void*)(&src[{row, col}]),
+                cuda::aligned_size_t<16>(sizeof(float4)),
+                barrier
+            );
+    }
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/group/memory/tile/shared_to_register.cuh b/triteia/csrc/flash_kittens/ops/group/memory/tile/shared_to_register.cuh
new file mode 100644
index 0000000..9334bac
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/group/memory/tile/shared_to_register.cuh
@@ -0,0 +1,107 @@
+/**
+ * @file
+ * @brief Functions for a warpgroup to collaboratively transfer data directly between shared memory and registers and back.
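[Editor's note] The `load_async`/`store_async` variants above expect the caller to own a block-scoped `cuda::barrier`; a sketch of the handshake follows. The tile type `st_bf_4x4` and the one-arrival-per-thread count are assumptions for illustration.

    #include <cuda/barrier>
    using namespace kittens;
    __global__ void async_demo(const bf16 *src, int row_stride) {
        extern __shared__ int __shm[];
        shared_allocator<16> al((int*)&__shm[0]);
        auto &tile = al.allocate<st_bf_4x4>();           // assumed shared tile type
        __shared__ cuda::barrier<cuda::thread_scope_block> bar;
        if (threadIdx.x == 0) init(&bar, blockDim.x);    // one arrival per thread
        __syncthreads();
        warpgroup::load_async(tile, src, row_stride, bar);
        bar.arrive_and_wait();                           // tile contents visible past this point
    }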
+ */ + +/** + * @brief Collaboratively load data from a shared tile into register tiles split across a warpgroup. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination register tile. + * @param src[in] The source shared tile. + */ +template +__device__ inline static void load(RT &dst, const ST &src) { + constexpr int height = ST::height; + constexpr int warp_height = RT::height; + static_assert(height%N_WARPS == 0, "Group load / store requires tile height to be a multiple of N_WARPS."); + static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height."); + static_assert(ST::width==RT::width, "Group load / store requires tile widths to match."); + int local_warpid = warpid(); + using T2 = RT::dtype; + using U = ST::dtype; + using T = base_types::packing::unpacked_type; + using U2 = base_types::packing::packed_type; + int warp_laneid = ::kittens::laneid(); + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if constexpr (std::is_same_v) { + // handle the row-major layout + int row = (local_warpid*warp_height + i)*dst.tile_size + (warp_laneid / 4); + int col = j*dst.tile_size + 2*(warp_laneid % 4); + dst.tiles[i][j].data[0] = base_types::convertor::convert(*(U2*)(&src[{row+0, col+0}])); + dst.tiles[i][j].data[1] = base_types::convertor::convert(*(U2*)(&src[{row+8, col+0}])); + dst.tiles[i][j].data[2] = base_types::convertor::convert(*(U2*)(&src[{row+0, col+8}])); + dst.tiles[i][j].data[3] = base_types::convertor::convert(*(U2*)(&src[{row+8, col+8}])); + } + else { + // handle the column-major layout + int row = (local_warpid*warp_height + i)*dst.tile_size + 2*(warp_laneid % 4); + int col = j*dst.tile_size + (warp_laneid / 4); + dst.tiles[i][j].data[0].x = base_types::convertor::convert(src[{row+0, col+0}]); + dst.tiles[i][j].data[0].y = base_types::convertor::convert(src[{row+1, col+0}]); + dst.tiles[i][j].data[1].x = base_types::convertor::convert(src[{row+0, col+8}]); + dst.tiles[i][j].data[1].y = base_types::convertor::convert(src[{row+1, col+8}]); + dst.tiles[i][j].data[2].x = base_types::convertor::convert(src[{row+8, col+0}]); + dst.tiles[i][j].data[2].y = base_types::convertor::convert(src[{row+9, col+0}]); + dst.tiles[i][j].data[3].x = base_types::convertor::convert(src[{row+8, col+8}]); + dst.tiles[i][j].data[3].y = base_types::convertor::convert(src[{row+9, col+8}]); + } + } + } +} + + +/** + * @brief Collaboratively store data into a shared tile from register tiles split across a warpgroup. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination shared tile. + * @param src[in] The source register tile. 
+ */ +template +__device__ inline static void store(ST &dst, const RT &src) { + constexpr int height = ST::height; + constexpr int warp_height = RT::height; + static_assert(height%N_WARPS == 0, "Group load / store requires tile height to be a multiple of N_WARPS."); + static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height."); + static_assert(ST::width==RT::width, "Group load / store requires tile widths to match."); + int local_warpid = warpid(); + using T2 = RT::dtype; + using U = ST::dtype; + using T = base_types::packing::unpacked_type; + using U2 = base_types::packing::packed_type; + int warp_laneid = ::kittens::laneid(); + #pragma unroll + for(int i = 0; i < warp_height; i++) { + #pragma unroll + for(int j = 0; j < src.width; j++) { + if constexpr (std::is_same_v) { + // handle the row-major layout + int row = (local_warpid*warp_height + i)*src.tile_size + (warp_laneid / 4); + int col = j*src.tile_size + 2*(warp_laneid % 4); + *(U2*)(&dst[{row+0, col+0}]) = base_types::convertor::convert(src.tiles[i][j].data[0]); + *(U2*)(&dst[{row+8, col+0}]) = base_types::convertor::convert(src.tiles[i][j].data[1]); + *(U2*)(&dst[{row+0, col+8}]) = base_types::convertor::convert(src.tiles[i][j].data[2]); + *(U2*)(&dst[{row+8, col+8}]) = base_types::convertor::convert(src.tiles[i][j].data[3]); + } + else { + // handle the column-major layout + int row = (local_warpid*warp_height + i)*src.tile_size + 2*(warp_laneid % 4); + int col = j*src.tile_size + (warp_laneid / 4); + dst[{row+0, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[0].x); + dst[{row+1, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[0].y); + dst[{row+0, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[1].x); + dst[{row+1, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[1].y); + dst[{row+8, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[2].x); + dst[{row+9, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[2].y); + dst[{row+8, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[3].x); + dst[{row+9, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[3].y); + } + } + } +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/memory/tile/tile.cuh b/triteia/csrc/flash_kittens/ops/group/memory/tile/tile.cuh new file mode 100644 index 0000000..8a9f4b0 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/memory/tile/tile.cuh @@ -0,0 +1,8 @@ +/** + * @file + * @brief An aggregate header of group memory operations on tiles. + */ + +#include "shared_to_register.cuh" +#include "global_to_register.cuh" +#include "global_to_shared.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_register.cuh b/triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_register.cuh new file mode 100644 index 0000000..c071957 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_register.cuh @@ -0,0 +1,43 @@ +/** + * @file + * @brief Functions for a warpgroup to collaboratively transfer data directly between global memory and registers and back. + */ + +/** + * @brief Collaboratively loads data into register vectors from a source array in global memory. + * + * @tparam RV The register vector type. + * @tparam U The data type of the source array. + * @param[out] dst The destination register vector to load data into. + * @param[in] src The source array in global memory to load data from. 
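[Editor's note] To make the height split enforced by the shared-to-register static_asserts above concrete, a sketch under assumed types: a 4-warp group dividing a 4-tile-tall shared tile gives each warp one 16-row slice.

    using namespace kittens;
    // st_bf_4x4 is 64x64; with N_WARPS == 4 each warp covers an rt_bf_1x4 (16 rows).
    __device__ void split_demo(rt_bf_1x4 &r, const st_bf_4x4 &s) {
        warpgroup::load(r, s);   // warp w reads rows [16*w, 16*w + 16)
    }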
+ */
+template<ducks::rv::all RV, typename U>
+__device__ inline static void load(RV &dst, const U *_src) {
+    using T2 = RV::dtype;
+    using U2 = base_types::packing<U>::packed_type;
+    using T = base_types::packing<T2>::unpacked_type;
+
+    const U *src = &_src[warpid() * dst.outer_dim * kittens::TILE_DIM]; // pretend smaller, do single warp load.
+
+    // Call warp-level load
+    ::kittens::load(dst, src);
+}
+/**
+ * @brief Collaboratively stores data from register vectors to a destination array in global memory.
+ *
+ * @tparam RV The register vector type.
+ * @tparam U The data type of the destination array.
+ * @param[out] dst The destination array in global memory to store data into.
+ * @param[in] src The source register vector to store data from.
+ */
+template<ducks::rv::all RV, typename U>
+__device__ inline static void store(U *_dst, const RV &src) {
+    using T2 = RV::dtype;
+    using U2 = base_types::packing<U>::packed_type;
+    using T = base_types::packing<T2>::unpacked_type;
+
+    U *dst = &_dst[warpid() * src.outer_dim * kittens::TILE_DIM]; // pretend smaller, do single warp store.
+
+    // Call warp-level store
+    ::kittens::store(dst, src);
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_shared.cuh b/triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_shared.cuh
new file mode 100644
index 0000000..a6be28d
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/group/memory/vec/global_to_shared.cuh
@@ -0,0 +1,68 @@
+/**
+ * @file
+ * @brief Group (collaborative warp) ops for loading shared vectors from and storing to global memory.
+ */
+
+/**
+ * @brief Loads data from global memory into a shared memory vector.
+ *
+ * This function loads data from a global memory location pointed to by `src` into a shared memory vector `dst`.
+ * It calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`.
+ * The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among the threads of the group.
+ *
+ * @tparam SV Shared vector type, must satisfy ducks::sv::all concept.
+ * @param dst Reference to the shared vector where the data will be loaded.
+ * @param src Pointer to the global memory location from where the data will be loaded.
+ */
+template<ducks::sv::all SV>
+__device__ static inline void load(SV &dst, const typename SV::dtype *src) {
+    constexpr int elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
+    constexpr int total_calls = dst.length / elem_per_transfer; // guaranteed to divide
+    __syncwarp();
+    #pragma unroll
+    for(int i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) {
+        if(i * elem_per_transfer < dst.length)
+            *(float4*)&dst[i*elem_per_transfer] = *(float4*)&src[i*elem_per_transfer];
+    }
+}
+
+template<ducks::sv::all SV>
+__device__ static inline void load_async(SV &dst, const typename SV::dtype *src, cuda::barrier<cuda::thread_scope_block> &barrier) {
+    constexpr int elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
+    constexpr int total_calls = dst.length / elem_per_transfer; // guaranteed to divide
+    __syncwarp();
+    #pragma unroll
+    for(int i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) {
+        if(i * elem_per_transfer < dst.length) {
+            cuda::memcpy_async(
+                (void*)&dst[i*elem_per_transfer],
+                (void*)&src[i*elem_per_transfer],
+                cuda::aligned_size_t<16>(sizeof(float4)),
+                barrier
+            );
+        }
+    }
+}
+
+/**
+ * @brief Stores data from a shared memory vector to global memory.
+ *
+ * This function stores data from a shared memory vector `src` to a global memory location pointed to by `dst`.
+ * Similar to the load function, it calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`. + * The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among threads in a warp. + * + * @tparam SV Shared vector type, must satisfy ducks::sv::all concept. + * @param dst Pointer to the global memory location where the data will be stored. + * @param src Reference to the shared vector from where the data will be stored. + */ +template +__device__ static inline void store(typename SV::dtype *dst, const SV &src) { + constexpr int elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype); + constexpr int total_calls = src.length / elem_per_transfer; // guaranteed to divide + __syncwarp(); + #pragma unroll + for(int i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) { + if(i * elem_per_transfer < src.length) + *(float4*)&dst[i*elem_per_transfer] = *(float4*)&src[i*elem_per_transfer]; // lmao it's identical + } +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/memory/vec/shared_to_register.cuh b/triteia/csrc/flash_kittens/ops/group/memory/vec/shared_to_register.cuh new file mode 100644 index 0000000..7d9ca37 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/memory/vec/shared_to_register.cuh @@ -0,0 +1,46 @@ +/** + * @file + * @brief Functions for a group to collaboratively transfer data directly between shared memory and registers and back. + */ + +/** + * @brief Collaboratively load data from a shared vector into register vectors split across a warpgroup. + * + * @tparam RV The register vector type + * @tparam SV The shared vector type + * @param dst[out] The destination register vector. + * @param src[in] The source shared vector. + */ +template +__device__ inline static void load(RV &dst, const SV &_src) { + using T2 = RV::dtype; + using U = SV::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + static_assert(_src.tiles == dst.outer_dim*N_WARPS);// confirm size correct + auto &src = subvec_inplace(_src, warpid()); // pretend it's smaller and do warp-level load + + ::kittens::load(dst, src); // warp-level +} + +/** + * @brief Collaboratively store data into a shared vector from register vectors split across a warpgroup. + * + * @tparam RV The register vector type + * @tparam SV The shared vector type + * @param dst[out] The destination shared vector. + * @param src[in] The source register vector. + */ +template +__device__ inline static void store(SV &_dst, const RV &src) { + using T2 = RV::dtype; + using U = SV::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + static_assert(_dst.tiles == src.outer_dim*N_WARPS);// confirm size correct + auto &dst = subvec_inplace(_dst, warpid()); // pretend it's smaller and do warp-level load + + ::kittens::store(dst, src); // warp-level +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/memory/vec/vec.cuh b/triteia/csrc/flash_kittens/ops/group/memory/vec/vec.cuh new file mode 100644 index 0000000..ed58f58 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/memory/vec/vec.cuh @@ -0,0 +1,8 @@ +/** + * @file + * @brief An aggregate header of group memory operations on vectors. 
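[Editor's note] A quick sanity check on the `float4` arithmetic used by the vector paths above; the 128-thread group size and 256-element vector are illustrative assumptions.

    // 16-byte float4 transfers over 2-byte bf16 elements: 16 / 2 == 8 elements per copy.
    // A 256-element shared vector needs 256 / 8 == 32 copies; with GROUP_THREADS == 128,
    // only threads 0..31 pass the bounds check and each issues exactly one copy.
    static_assert(sizeof(float4) == 16, "float4 moves 16 bytes per transfer");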
+ */ + +#include "shared_to_register.cuh" +#include "global_to_register.cuh" +#include "global_to_shared.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/shared/shared.cuh b/triteia/csrc/flash_kittens/ops/group/shared/shared.cuh new file mode 100644 index 0000000..6558b07 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/shared/shared.cuh @@ -0,0 +1,7 @@ +/** + * @file + * @brief An aggregate header of group operations on data in shared memory + */ + +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/shared/tile/conversions.cuh b/triteia/csrc/flash_kittens/ops/group/shared/tile/conversions.cuh new file mode 100644 index 0000000..19335d5 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/shared/tile/conversions.cuh @@ -0,0 +1,27 @@ +/** + * @file + * @brief Group conversions between different shared memory tile types. + */ + +/* ---------- COPIES ---------- */ + +/** + * @brief Copies data from one shared memory tile to another, potentially with different data types and layouts. + * + * @tparam T The data type of the destination tile. + * @tparam U The data type of the source tile. + * @tparam _height The height of the tile. + * @tparam _width The width of the tile. + * @tparam L1 The layout of the destination tile. + * @tparam L2 The layout of the source tile. + * @param[out] dst The destination tile. + * @param[in] src The source tile. + */ +template +__device__ static inline void copy(st &dst, const st &src) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i+=GROUP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = base_types::convertor::convert(src[{row, col}]); + } +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/shared/tile/maps.cuh b/triteia/csrc/flash_kittens/ops/group/shared/tile/maps.cuh new file mode 100644 index 0000000..5ea1f06 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/shared/tile/maps.cuh @@ -0,0 +1,433 @@ +/** + * @file + * @brief Group maps on shared tiles. + */ + +/** + * @brief Performs a uniform unary operation on a tile. + * + * This function applies a given unary operation to each element of the source tile and stores the result in the destination tile. + * The operation is applied independently to each element, without considering its position or the values of neighboring elements. + * + * @tparam op The unary operation to be applied. Must be specialized to support operation on the data type of T. + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the unary operation is applied. + */ +template // T2, w, h can be inferred from dst as long as op is specialized +__device__ static inline void unary_map(T &dst, const T &src) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + dst.data[i] = op::template op(src.data[i]); + } +} + +/** + * @brief Performs a uniform binary operation on a tile with a scalar parameter. + * + * This function applies a given binary operation to each element of the source tile and a scalar parameter, then stores the result in the destination tile. + * The operation is applied independently to each element, treating the scalar parameter as the second operand for each operation. + * + * @tparam op The binary operation to be applied. 
Must be specialized to support operation on the data type of T and the scalar parameter. + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the binary operation is applied. + * @param[in] param The scalar parameter to be used as the second operand in the binary operation. + */ +template +__device__ static inline void bin_map(T &dst, const T &src, const typename T::dtype ¶m) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + dst.data[i] = op::template op(src.data[i], param); + } +} + +/** + * @brief Performs a uniform binary operation on two tiles. + * + * This function applies a given binary operation to corresponding elements of two source tiles and stores the result in the destination tile. + * The operation is applied independently to each pair of elements, without considering their positions or the values of neighboring elements. + * + * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T. + * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile to which the binary operation is applied. + * @param[in] rhs The second source tile to which the binary operation is applied. + */ +template +__device__ static inline void bin_map(T &dst, const T &lhs, const T &rhs) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + dst.data[i] = op::template op(lhs.data[i], rhs.data[i]); + } +} + +/** + * @brief Performs a row-wise binary operation on a tile with a vector. + * + * This function applies a given binary operation to each row of the source tile and the corresponding element of the source vector, + * then stores the result in the destination tile. The operation is applied independently to each row, using the vector element as + * the second operand for each element in the row. + * + * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements. + * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept. + * @tparam V The type of the vector. Must have the same data type as T. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the binary operation is applied. + * @param[in] vec The source vector containing the second operand for each row operation. + */ +template +__device__ static inline void row_map(T &dst, const T &src, const V &vec) { + static_assert(std::is_same::value, "Tile and vector must have the same data type"); + static_assert(V::length == T::rows, "Vector length must match the number of rows in the tile"); + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = op::template op(src[{row, col}], vec[row]); + } +} + +/** + * @brief Performs a column-wise binary operation on a tile with a vector. + * + * This function applies a given binary operation to each column of the source tile and the corresponding element of the source vector, + * then stores the result in the destination tile. The operation is applied independently to each column, using the vector element as + * the second operand for each element in the column. 
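[Editor's note] These maps only require `op` to expose a static, type-templated `op` method in the `base_ops` style, so user-defined elementwise ops drop in without new machinery. A hypothetical example (`squared` is not part of the library):

    // A hypothetical elementwise op usable with the maps above.
    struct squared {
        template<typename T>
        static __device__ inline T op(const T &x) { return x * x; }
    };
    // From a group scope (e.g. kittens::warpgroup):
    //   unary_map<squared>(dst_tile, src_tile);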
+ * + * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements. + * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept. + * @tparam V The type of the vector. Must have the same data type as T. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the binary operation is applied. + * @param[in] vec The source vector containing the second operand for each column operation. + */ +template +__device__ static inline void col_map(T &dst, const T &src, const V &vec) { + static_assert(std::is_same::value, "Tile and vector must have the same data type"); + static_assert(V::length == T::cols, "Vector length must match the number of columns in the tile"); + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = op::template op(src[{row, col}], vec[col]); + } +} + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// All of the annoying qualifiers *should* be automatically inferred during compile-time. +// So, syntax should just be kittens::add_row(tile, colvec); + +// const maps +/** + * @brief Sets all elements of the destination tile to zero. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void zero(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of the destination tile to one. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void one(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of the destination tile to positive infinity. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void pos_infty(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of the destination tile to negative infinity. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void neg_infty(T &dst) { + unary_map(dst, dst); +} + +// unary maps +/** + * @brief Applies the exponential function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the exponential function is applied. + */ +template +__device__ static inline void exp(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the natural logarithm function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the natural logarithm function is applied. + */ +template +__device__ static inline void log(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the absolute function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. 
Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the absolute function is applied. + */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the rectified linear unit function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the rectified linear unit function is applied. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Copies the elements of the source tile to the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source data to be copied. + */ +template +__device__ static inline void copy(T &dst, const U &src) { + bin_map(dst, src); +} + +// uniform binary maps +/** + * @brief Finds the maximum of each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Finds the minimum of each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Adds each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Subtracts each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. 
+ * @param[in] rhs The second source data. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Multiplies each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Divides each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +// Row and col maps + +/** + * @brief Adds row values to each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the addition on. + * @param row_values[in] Column vector containing values to add to each row. + */ +template +__device__ static inline void add_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Subtracts row values from each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the subtraction on. + * @param row_values[in] Column vector containing values to subtract from each row. + */ +template +__device__ static inline void sub_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Multiplies each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the multiplication on. + * @param row_values[in] Column vector containing values to multiply each row by. + */ +template +__device__ static inline void mul_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Divides each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the division on. + * @param row_values[in] Column vector containing values to divide each row by. + */ +template +__device__ static inline void div_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Broadcast a vector into into a tile's rows. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param row_values[in] Column vector containing values to broadcast into rows. 
+ */
+template<ducks::st::all T, ducks::sv::all V>
+__device__ static inline void broadcast_row(T &dst, const V &row_values) {
+    row_map<base_ops::copy2, T, V>(dst, dst, row_values);
+}
+
+
+// col maps
+/**
+ * @brief Adds column values to each column of a tile.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the addition on.
+ * @param col_values[in] Row vector containing values to add to each column.
+ */
+template<ducks::st::all T, ducks::sv::all V>
+__device__ static inline void add_col(T &dst, const T &src, const V &col_values) {
+    col_map<base_ops::sum, T, V>(dst, src, col_values);
+}
+/**
+ * @brief Subtracts column values from each column of a tile.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the subtraction on.
+ * @param col_values[in] Row vector containing values to subtract from each column.
+ */
+template<ducks::st::all T, ducks::sv::all V>
+__device__ static inline void sub_col(T &dst, const T &src, const V &col_values) {
+    col_map<base_ops::sub, T, V>(dst, src, col_values);
+}
+/**
+ * @brief Multiplies each column of a tile by column values.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the multiplication on.
+ * @param col_values[in] Row vector containing values to multiply each column by.
+ */
+template<ducks::st::all T, ducks::sv::all V>
+__device__ static inline void mul_col(T &dst, const T &src, const V &col_values) {
+    col_map<base_ops::mul, T, V>(dst, src, col_values);
+}
+/**
+ * @brief Divides each column of a tile by column values.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the division on.
+ * @param col_values[in] Row vector containing values to divide each column by.
+ */
+template<ducks::st::all T, ducks::sv::all V>
+__device__ static inline void div_col(T &dst, const T &src, const V &col_values) {
+    col_map<base_ops::div, T, V>(dst, src, col_values);
+}
+/**
+ * @brief Broadcast a vector into a tile's columns.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param col_values[in] Row vector containing values to broadcast into columns.
+ */
+template<ducks::st::all T, ducks::sv::all V>
+__device__ static inline void broadcast_col(T &dst, const V &col_values) {
+    col_map<base_ops::copy2, T, V>(dst, dst, col_values);
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/group/shared/tile/reductions.cuh b/triteia/csrc/flash_kittens/ops/group/shared/tile/reductions.cuh
new file mode 100644
index 0000000..ed13051
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/group/shared/tile/reductions.cuh
@@ -0,0 +1,266 @@
+/**
+ * @file
+ * @brief Group reductions on shared tiles.
+ */
+
+/**
+ * Performs row-wise reduction on a matrix using a specified operation.
+ *
+ * @tparam op The operation to be applied for reduction.
+ * @tparam V The shared vector type for the row accumulator.
+ * @tparam T The shared matrix type with row layout.
+ * @param row_accum The accumulator where the result of the reduction is stored.
+ * @param src The source matrix on which to perform the reduction.
+ * @param src_accum The initial value of the accumulator, used when reset is false.
+ * @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
+ */ +template +__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) { + using dtype = typename V::dtype; + for (int row = laneid(); row < src.rows; row += GROUP_THREADS) { + dtype accum = src[{row, 0}]; + #pragma unroll + for (int col = 1; col < src.cols; col++) { + accum = op::template op(accum, src[{row, col}]); + } + if (reset) { + row_accum[row] = accum; + } else { + row_accum[row] = op::template op(src_accum[row], accum); + } + } +} + +/** + * Performs column-wise reduction on a matrix using a specified operation. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The shared vector type for the column accumulator. + * @tparam T The shared matrix type with column layout. + * @param col_accum The accumulator where the result of the reduction is stored. + * @param src The source matrix on which to perform the reduction. + * @param src_accum The initial value of the accumulator, used when reset is false. + * @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + */ +template +__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) { + using dtype = typename V::dtype; + for (int col = laneid(); col < src.cols; col += GROUP_THREADS) { + dtype accum = src[{0, col}]; + #pragma unroll + for (int row = 1; row < src.rows; row++) { + accum = op::template op(accum, src[{row, col}]); + } + if (reset) { + col_accum[col] = accum; + } else { + col_accum[col] = op::template op(src_accum[col], accum); + } + } +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +/** + * @brief Store the maximum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the minimum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the sum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_sum(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the product of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. 
+ */ +template +__device__ static inline void row_prod(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} + +/** + * @brief Store the maximum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the minimum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the sum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_sum(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the product of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_prod(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} + +/** + * @brief Store the maximum of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_max(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the minimum of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. 
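[Editor's note] Composed together, the row maps and row reductions above cover the common row-softmax pattern. A sketch under assumed types (`st_bf_4x4` and its `col_vec`), with block-wide syncs standing in for a group barrier since the maps and reductions impose no ordering on their own; this assumes the group spans every thread touching the tile.

    using namespace kittens;
    __device__ void softmax_rows(st_bf_4x4 &tile, st_bf_4x4::col_vec &scratch) {
        warpgroup::row_max(scratch, tile);       // scratch[i] = max_j tile[i][j]
        __syncthreads();                         // publish scratch to the whole group
        warpgroup::sub_row(tile, tile, scratch); // subtract the row max
        warpgroup::exp(tile, tile);              // elementwise exponent (same-thread indices, no sync needed)
        __syncthreads();                         // rows must be complete before summing
        warpgroup::row_sum(scratch, tile);       // scratch[i] = sum_j tile[i][j]
        __syncthreads();
        warpgroup::div_row(tile, tile, scratch); // normalize each row
    }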
+ * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_min(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the sum of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_sum(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the product of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_prod(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} + +/** + * @brief Store the maximum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_max(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the minimum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_min(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the sum of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_sum(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the product of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. 
+ * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_prod(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/shared/tile/tile.cuh b/triteia/csrc/flash_kittens/ops/group/shared/tile/tile.cuh new file mode 100644 index 0000000..9ecca64 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/shared/tile/tile.cuh @@ -0,0 +1,8 @@ +/** + * @file + * @brief An aggregate header for group operations on shared tiles. + */ + +#include "conversions.cuh" +#include "maps.cuh" +#include "reductions.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/shared/vec/conversions.cuh b/triteia/csrc/flash_kittens/ops/group/shared/vec/conversions.cuh new file mode 100644 index 0000000..0b52769 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/shared/vec/conversions.cuh @@ -0,0 +1,27 @@ +/** + * @file + * @brief Group conversions on shared vectors. + */ + +/** + * @brief Copies data from one shared vector to another, converting data types if necessary. + * + * This function copies data from the source shared vector `src` to the destination shared vector `dst`. + * If the data types of `src` and `dst` are the same, it performs a direct memory copy. Otherwise, it + * converts each element from the source data type to the destination data type using the appropriate + * converter before copying. + * + * @tparam SV1 The type of the destination shared vector, must satisfy the ducks::sv::all concept. + * @tparam SV2 The type of the source shared vector, must satisfy the ducks::sv::all concept. + * @param[out] dst The destination shared vector. + * @param[in] src The source shared vector. + * @note The lengths of `src` and `dst` must be equal. This is enforced at compile time. + */ +template +__device__ static inline void copy(SV1 &dst, const SV2 &src) { + static_assert(dst.length == src.length, "Source and destination vectors must have the same length."); + #pragma unroll + for(int i = laneid(); i < dst.length; i+=GROUP_THREADS) { + dst[i] = base_types::convertor::convert(src[i]); + } +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/shared/vec/maps.cuh b/triteia/csrc/flash_kittens/ops/group/shared/vec/maps.cuh new file mode 100644 index 0000000..5069b30 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/shared/vec/maps.cuh @@ -0,0 +1,237 @@ +/** + * @file + * @brief Group maps on shared vectors. + */ + +/** + * @brief Applies a unary operation to each element of a shared memory vector. + * + * @tparam op Unary operation type. + * @tparam T Shared memory vector type. + * @param dst[out] Destination vector in which to store the result. + * @param src[in] Source vector to apply the unary operation. + */ +template +__device__ static inline void unary_op(T &dst, const T &src) { + #pragma unroll + for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) { + dst[cur] = op::template op(src[cur]); + } +} +/** + * @brief Perform a binary operation on two shared vectors. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vectors. + * @param dst[out] The destination vector where the result is stored. 
+/**
+ * @brief Perform a binary operation on two shared vectors.
+ *
+ * @tparam op The binary operation to perform.
+ * @tparam T The type of the vectors.
+ * @param dst[out] The destination vector where the result is stored.
+ * @param lhs[in] The left-hand side vector for the operation.
+ * @param rhs[in] The right-hand side vector for the operation.
+ */
+template<typename op, ducks::sv::all T>
+__device__ static inline void bin_op(T &dst, const T &lhs, const T &rhs) {
+    #pragma unroll
+    for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) {
+        dst[cur] = op::template op<typename T::dtype>(lhs[cur], rhs[cur]);
+    }
+}
+/**
+ * @brief Perform a binary operation on a shared vector and a scalar.
+ *
+ * @tparam op The binary operation to perform.
+ * @tparam T The type of the vector.
+ * @param dst[out] The destination vector where the result is stored.
+ * @param src[in] The source vector for the operation.
+ * @param param[in] The scalar parameter for the operation.
+ */
+template<typename op, ducks::sv::all T>
+__device__ static inline void bin_op(T &dst, const T &src, const typename T::dtype &param) {
+    #pragma unroll
+    for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) {
+        dst[cur] = op::template op<typename T::dtype>(src[cur], param);
+    }
+}
+
+/* ---------- WRAPPERS FOR PRETTINESS ---------- */
+
+// ---- const ops ----
+
+/**
+ * @brief Sets all elements of a shared memory vector to zero.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to zero.
+ */
+template<ducks::sv::all T>
+__device__ static inline void zero(T &dst) {
+    unary_op<base_ops::zero, T>(dst, dst);
+}
+/**
+ * @brief Sets all elements of a shared memory vector to one.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to one.
+ */
+template<ducks::sv::all T>
+__device__ static inline void one(T &dst) {
+    unary_op<base_ops::one, T>(dst, dst);
+}
+/**
+ * @brief Sets all elements of a shared memory vector to positive infinity.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to positive infinity.
+ */
+template<ducks::sv::all T>
+__device__ static inline void pos_infty(T &dst) {
+    unary_op<base_ops::pos_infty, T>(dst, dst);
+}
+/**
+ * @brief Sets all elements of a shared memory vector to negative infinity.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to negative infinity.
+ */
+template<ducks::sv::all T>
+__device__ static inline void neg_infty(T &dst) {
+    unary_op<base_ops::neg_infty, T>(dst, dst);
+}
+
+// ---- unary ops ----
+
+/**
+ * @brief Copies the elements from one shared vector to another.
+ *
+ * @tparam T Shared vector type.
+ * @tparam U Type of the source vector.
+ * @param dst[out] Destination vector where the elements will be copied to.
+ * @param src[in] Source vector to copy the elements from.
+ */
+template<ducks::sv::all T, typename U>
+__device__ static inline void copy(T &dst, const U &src) {
+    bin_op<base_ops::copy2, T>(dst, dst, src); // copy2 returns its second argument, so the current dst value is ignored here.
+}
+/**
+ * @brief Applies the exponential function element-wise to a shared vector.
+ *
+ * @tparam T Shared vector type.
+ * @param dst[out] Destination vector where the exponential values will be stored.
+ * @param src[in] Source vector to apply the exponential function to.
+ */
+template<ducks::sv::all T>
+__device__ static inline void exp(T &dst, const T &src) {
+    unary_op<base_ops::exp, T>(dst, src);
+}
+/**
+ * @brief Applies the natural logarithm function element-wise to a shared vector.
+ *
+ * @tparam T Shared vector type.
+ * @param dst[out] Destination vector where the logarithm values will be stored.
+ * @param src[in] Source vector to apply the logarithm function to.
+ */
+template<ducks::sv::all T>
+__device__ static inline void log(T &dst, const T &src) {
+    unary_op<base_ops::log, T>(dst, src);
+}
+/**
+ * @brief Applies the absolute value function element-wise to a shared vector.
+ *
+ * @tparam T Shared vector type.
+ * @param dst[out] Destination vector where the absolute values will be stored.
+ * @param src[in] Source vector to apply the absolute value function to.
+ */
+template<ducks::sv::all T>
+__device__ static inline void abs(T &dst, const T &src) {
+    unary_op<base_ops::abs, T>(dst, src);
+}
+/**
+ * @brief Applies the rectified linear unit (ReLU) function element-wise to a shared vector.
+ *
+ * @tparam T Shared vector type.
+ * @param dst[out] Destination vector where the ReLU values will be stored.
+ * @param src[in] Source vector to apply the ReLU function to.
+ */
+template<ducks::sv::all T>
+__device__ static inline void relu(T &dst, const T &src) {
+    unary_op<base_ops::relu, T>(dst, src);
+}
+
+// ---- binary ops ----
+
+/**
+ * @brief Computes the element-wise maximum of two shared vectors.
+ *
+ * @tparam T Shared vector type.
+ * @tparam U Type of the second vector.
+ * @param dst[out] Destination vector where the maximum values will be stored.
+ * @param lhs[in] First vector for the maximum operation.
+ * @param rhs[in] Second vector for the maximum operation.
+ */
+template<ducks::sv::all T, typename U>
+__device__ static inline void max(T &dst, const T &lhs, const U &rhs) {
+    bin_op<base_ops::max, T>(dst, lhs, rhs);
+}
+/**
+ * @brief Computes the element-wise minimum of two shared vectors.
+ *
+ * @tparam T Shared vector type.
+ * @tparam U Type of the second vector.
+ * @param dst[out] Destination vector where the minimum values will be stored.
+ * @param lhs[in] First vector for the minimum operation.
+ * @param rhs[in] Second vector for the minimum operation.
+ */
+template<ducks::sv::all T, typename U>
+__device__ static inline void min(T &dst, const T &lhs, const U &rhs) {
+    bin_op<base_ops::min, T>(dst, lhs, rhs);
+}
+/**
+ * @brief Computes the element-wise sum of two shared vectors.
+ *
+ * @tparam T Shared vector type.
+ * @tparam U Type of the second vector.
+ * @param dst[out] Destination vector where the sum values will be stored.
+ * @param lhs[in] First vector for the sum operation.
+ * @param rhs[in] Second vector for the sum operation.
+ */
+template<ducks::sv::all T, typename U>
+__device__ static inline void add(T &dst, const T &lhs, const U &rhs) {
+    bin_op<base_ops::sum, T>(dst, lhs, rhs);
+}
+/**
+ * @brief Computes the element-wise difference of two shared vectors.
+ *
+ * @tparam T Shared vector type.
+ * @tparam U Type of the second vector.
+ * @param dst[out] Destination vector where the difference values will be stored.
+ * @param lhs[in] First vector for the difference operation.
+ * @param rhs[in] Second vector for the difference operation.
+ */
+template<ducks::sv::all T, typename U>
+__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) {
+    bin_op<base_ops::sub, T>(dst, lhs, rhs);
+}
+/**
+ * @brief Computes the element-wise product of two shared vectors.
+ *
+ * @tparam T Shared vector type.
+ * @tparam U Type of the second vector.
+ * @param dst[out] Destination vector where the product values will be stored.
+ * @param lhs[in] First vector for the product operation.
+ * @param rhs[in] Second vector for the product operation.
+ */
+template<ducks::sv::all T, typename U>
+__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) {
+    bin_op<base_ops::mul, T>(dst, lhs, rhs);
+}
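+// Illustrative call pattern for these wrappers (a sketch only; `v` and `w`
+// stand for shared vectors of matching length, with float dtype assumed for
+// the scalar form):
+//     zero(v);          // v <- 0
+//     add(v, v, w);     // v <- v + w, elementwise
+//     mul(v, v, 0.5f);  // U = float binds bin_op's scalar overload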
+/**
+ * @brief Computes the element-wise division of two shared vectors.
+ *
+ * @tparam T Shared vector type.
+ * @tparam U Type of the second vector.
+ * @param dst[out] Destination vector where the division values will be stored.
+ * @param lhs[in] First vector for the division operation.
+ * @param rhs[in] Second vector for the division operation.
+ */
+template<ducks::sv::all T, typename U>
+__device__ static inline void div(T &dst, const T &lhs, const U &rhs) {
+    bin_op<base_ops::div, T>(dst, lhs, rhs);
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/group/shared/vec/vec.cuh b/triteia/csrc/flash_kittens/ops/group/shared/vec/vec.cuh
new file mode 100644
index 0000000..5c2237f
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/group/shared/vec/vec.cuh
@@ -0,0 +1,9 @@
+/**
+ * @file
+ * @brief An aggregate header for group operations on shared vectors.
+ */
+
+#include "conversions.cuh"
+#include "maps.cuh"
+// no group vector reductions, as they would require additional shared memory and synchronization, and those side effects just aren't worth it:
+// warp vector reductions should be plenty fast in 99.9% of situations.
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x1.impl b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x1.impl
new file mode 100644
index 0000000..49a53d3
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x1.impl
@@ -0,0 +1,69 @@
+template<int trans_a, int trans_b>
+struct base<1, trans_a, trans_b> {
+    __device__ static inline void rt_st(
+        rt_fl<1, 1, ducks::rt_layout::row> &dst,
+        const rt_base_bf & a_rt,
+        const uint64_t b_st_desc,
+        int scale_d = 1
+    ) {
+        asm volatile (
+            "{\n"
+            ".reg .pred p;\n" \
+            "setp.ne.b32 p, %13, 0;\n" \
+            "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " \
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, " \
+            "{%8, %9, %10, %11}, " \
+            "%12, " \
+            "p, 1, 1, %14;\n" \
+            "}\n"
+            // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b
+
+            : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
+              "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
+              "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
+              "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y)
+
+            : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
+              "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
+
+              "l"(b_st_desc), "r"(scale_d), "n"(trans_b)
+        );
+    }
+    __device__ static inline void st_st(
+        rt_fl<1, 1, ducks::rt_layout::row> &dst,
+        const uint64_t a_st_desc,
+        const uint64_t b_st_desc,
+        int scale_d = 1
+    ) {
+        asm volatile (
+            "{\n"
+            ".reg .pred p;\n" \
+            "setp.ne.b32 p, %10, 0;\n" \
+            "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " \
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, " \
+            "%8, " \
+            "%9, " \
+            "p, 1, 1, %11, %12;\n" \
+            "}\n"
+            // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b
+
+            : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
+              "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
+              "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
+              "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y)
+
+            : "l"(a_st_desc),
+              "l"(b_st_desc),
+
+              "r"(scale_d),
+              "n"(trans_a),
+              "n"(trans_b)
+        );
+    }
+    template<ducks::st::all ST> __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) {
+        return make_descriptor<ST, trans_a>(tile, chunk_idx);
+    }
+    template<ducks::st::all ST> __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) {
+        return make_descriptor<ST, trans_b>(tile, chunk_idx);
+    }
+};
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x16.impl b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x16.impl
new file mode 100644
index 0000000..3a85aa5
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x16.impl
@@ -0,0 +1,189 @@
+template<int trans_a, int trans_b>
+struct base<16, trans_a, trans_b> {
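+    // For the m64n256k16 shape, each thread of the warpgroup owns 128 fp32
+    // accumulator values (16 tiles x 4 float2 each), which is why the asm
+    // blocks below enumerate register operands %0 through %127 one by one.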
__device__ static inline void rt_st( + rt_fl<1, 16, ducks::rt_layout::row> &dst, + const rt_base_bf & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %133, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "{%128, %129, %130, %131}, " \ + "%132, " \ + "p, 1, 1, %134;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), 
"+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b) + ); + } + __device__ static inline void st_st( + rt_fl<1, 16, ducks::rt_layout::row> &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %130, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "%128, " \ + "%129, " \ + "p, 1, 1, %131, %132;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 
0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + 
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b) + ); + } + template __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } + template __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } +}; \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x2.impl b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x2.impl new file mode 100644 index 0000000..b8d75a5 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x2.impl @@ -0,0 +1,77 @@ +template +struct base<2, trans_a, trans_b> { + __device__ static inline void rt_st( + rt_fl<1, 2, ducks::rt_layout::row> &dst, + const rt_base_bf & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, 1, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b) + ); + } + __device__ static inline void st_st( + rt_fl<1, 2, ducks::rt_layout::row> &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, 1, %19, %20;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + 
"+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b) + ); + } + template __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } + template __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } +}; \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x3.impl b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x3.impl new file mode 100644 index 0000000..c58aa0c --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x3.impl @@ -0,0 +1,85 @@ +template +struct base<3, trans_a, trans_b> { + __device__ static inline void rt_st( + rt_fl<1, 3, ducks::rt_layout::row> &dst, + const rt_base_bf & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %29, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "{%24, %25, %26, %27}, " \ + "%28, " \ + "p, 1, 1, %30;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b) + ); + } + __device__ static inline void st_st( + rt_fl<1, 3, ducks::rt_layout::row> &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %26, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "%24, " \ + "%25, " \ + "p, 1, 1, %27, %28;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + 
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b) + ); + } + template __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } + template __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } +}; \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x4.impl b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x4.impl new file mode 100644 index 0000000..4413f08 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x4.impl @@ -0,0 +1,93 @@ +template +struct base<4, trans_a, trans_b> { + __device__ static inline void rt_st( + rt_fl<1, 4, ducks::rt_layout::row> &dst, + const rt_base_bf & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, 1, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b) + ); + } + __device__ static inline void st_st( + rt_fl<1, 4, ducks::rt_layout::row> &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, 1, %35, %36;\n" \ + "}\n" + // a_mat descriptor, b_mat 
descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b) + ); + } + template __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } + template __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } +}; \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x6.impl b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x6.impl new file mode 100644 index 0000000..ef37c45 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x6.impl @@ -0,0 +1,109 @@ +template +struct base<6, trans_a, trans_b> { + __device__ static inline void rt_st( + rt_fl<1, 6, ducks::rt_layout::row> &dst, + const rt_base_bf & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %53, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "{%48, %49, %50, %51}, " \ + "%52, " \ + "p, 1, 1, %54;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), 
"+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b) + ); + } + __device__ static inline void st_st( + rt_fl<1, 6, ducks::rt_layout::row> &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %50, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "%48, " \ + "%49, " \ + "p, 1, 1, %51, %52;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b) + ); + } + template __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } + template __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } +}; \ No newline at end of file 
diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x8.impl b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x8.impl new file mode 100644 index 0000000..6b93266 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/4x8.impl @@ -0,0 +1,125 @@ +template +struct base<8, trans_a, trans_b> { + __device__ static inline void rt_st( + rt_fl<1, 8, ducks::rt_layout::row> &dst, + const rt_base_bf & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, 1, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b) + ); + } + __device__ static inline void st_st( + rt_fl<1, 8, ducks::rt_layout::row> &dst, + const uint64_t 
a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, 1, %67, %68;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b) + ); + } + template __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } + template __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } +}; \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/base/base.cuh b/triteia/csrc/flash_kittens/ops/group/wgmma/base/base.cuh new file mode 100644 index 0000000..fa3b3e0 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/wgmma/base/base.cuh @@ -0,0 +1,119 @@ +#pragma once + +#include 
"../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { +namespace ducks { +namespace wgmma { +template +concept normal = ( + std::is_same_v || + std::is_same_v +); +template +concept transposed = ( + std::is_same_v // || +); +} +} +namespace wgmma { + +// see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor +__device__ static inline uint64_t matrix_descriptor_encode(uint64_t x) { return (((x) & 0x3FFFF) >> 0x4); } + +// wgmma helpers +template +struct descriptor { static_assert("Asbtract wgmma descriptor struct should never be instantiated."); }; + +template +struct descriptor { + __device__ static inline uint64_t normal(uint64_t start_addr, int chunk_idx) { + uint64_t desc = 0x0000000000000000; + desc |= matrix_descriptor_encode(start_addr + chunk_idx*(128 * 2)); + desc |= matrix_descriptor_encode((uint64_t)128) << 16; + desc |= matrix_descriptor_encode((uint64_t)256*width) << 32; + return desc; + } + __device__ static inline uint64_t transposed(uint64_t start_addr, int chunk_idx) { + uint64_t desc = 0x0000000000000000; + desc |= matrix_descriptor_encode(start_addr + chunk_idx*(256*width * 2)); + desc |= matrix_descriptor_encode((uint64_t)256*width) << 16; + desc |= matrix_descriptor_encode((uint64_t)128) << 32; + return desc; + } +}; +template +struct descriptor { + __device__ static inline uint64_t normal(uint64_t start_addr, int chunk_idx) { + uint64_t desc = 0x0000000000000000; + if constexpr (width%4 == 0) { + desc |= matrix_descriptor_encode(start_addr + (chunk_idx%4)*32 + (chunk_idx/4)*height*2048); + desc |= matrix_descriptor_encode((uint64_t)16) << 16; + desc |= matrix_descriptor_encode((uint64_t)1024) << 32; + desc |= 1llu << 62; // set wgmma_swizzle mode + } + else if constexpr (width%2 == 0) { + desc |= matrix_descriptor_encode(start_addr + (chunk_idx%2)*32 + (chunk_idx/2)*height*1024); + desc |= matrix_descriptor_encode((uint64_t)16) << 16; + desc |= matrix_descriptor_encode((uint64_t)512) << 32; + desc |= 2llu << 62; // set wgmma_swizzle mode + } + else { + desc |= matrix_descriptor_encode(start_addr + chunk_idx*height*512); + desc |= matrix_descriptor_encode((uint64_t)16) << 16; + desc |= matrix_descriptor_encode((uint64_t)256) << 32; + desc |= 3llu << 62; // set wgmma_swizzle mode + } + return desc; + } +}; + +template +__device__ static inline uint64_t make_descriptor(const ST &tile, int chunk_idx) { + if constexpr (transpose) { + static_assert(ducks::wgmma::transposed, "Tile must have a transposable wgmma layout to be used here."); + return descriptor::transposed((uint64_t)(tile.data), chunk_idx); + } + else { + static_assert(ducks::wgmma::normal, "Tile must have a normal wgmma layout to be used here."); + return descriptor::normal((uint64_t)(tile.data), chunk_idx); + } +} +// templated wrapper for PTX +template +struct base { + __device__ static inline void rt_st( + rt_fl<1, width, ducks::rt_layout::row> &dst, + const rt_bf<1, 1, ducks::rt_layout::row> & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ); + __device__ static inline void st_st( + rt_fl<1, width, ducks::rt_layout::row> &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ); + template __device__ static inline uint64_t a_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } + template __device__ static inline uint64_t b_desc(const ST &tile, int chunk_idx) { + return make_descriptor(tile, chunk_idx); + } +}; + 
+#include "4x1.impl" +#include "4x2.impl" +#include "4x3.impl" +#include "4x4.impl" + +// can add bigger ones later, just annoying +// #include "4x5.impl" +#include "4x6.impl" +// #include "4x7.impl" +#include "4x8.impl" +#include "4x16.impl" + +} +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/group/wgmma/wgmma.cuh b/triteia/csrc/flash_kittens/ops/group/wgmma/wgmma.cuh new file mode 100644 index 0000000..7526d9d --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/group/wgmma/wgmma.cuh @@ -0,0 +1,333 @@ +/** + * @file + * @brief Warpgroup matrix-multiply accumulate operations. These ops are necessary to achieve full utilization on H100 GPUs. + */ + + + /* + ### OPTIONS: + + REG+SMEM -> REG + - mma_AB (accum) [DONE] + - mm_AB (reset) [DONE] + - mma_ABt (accum) [DONE] + - mm_ABt (reset) [DONE] + + SMEM+SMEM -> REG + - mma_AB (accum) [DONE] + - mm_AB (reset) [DONE] + - mma_ABt (accum) [DONE] + - mm_ABt (reset) [DONE] + - mma_AtB (accum) [DONE] + - mm_AtB (reset) [DONE] + - mma_AtBt (accum) [DONE] + - mm_AtBt (reset) [DONE] + +Note: mma is an alias for mma_AB and dot is an alias for mma_ABt +*/ + +// [(register, shared) -> register] edition +/** + * @brief Perform matrix multiply-accumulate operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function multiplies a register tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam N_DIV_4 The height of the matrix `a` divided by 4. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The width of the matrices `b` and `d`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source register tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_AB(rt_fl &d, + const rt_bf &a, + const st_bf &b) { + KITTENS_CHECK_WARPGROUP + using base = kittens::wgmma::base; + #pragma unroll + for(int n = 0; n < N_DIV_4; n++) { + rt_fl<1, M, ducks::rt_layout::row> &d_ref = subtile_inplace<1>(d, n); + base::rt_st( + d_ref, + a.tiles[n][0], + base::b_desc(b, 0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.tiles[n][k], + base::b_desc(b, k), + 1 + ); + } + } +} +template +__device__ static inline void mm_AB(rt_fl &d, + const rt_bf &a, + const st_bf &b) { + mma_AB(d, a, b); +} + +template +__device__ static inline void mma_AB(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf<4, K, L_A> &a, + const st_bf &b) { + KITTENS_CHECK_WARPGROUP + using base = kittens::wgmma::base; + base::st_st( + d, + base::a_desc(a, 0), + base::b_desc(b, 0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + base::a_desc(a, k), + base::b_desc(b, k), + 1 + ); + } +} +template +__device__ static inline void mm_AB(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf<4, K, L_A> &a, + const st_bf &b) { + mma_AB(d, a, b); +} + +// [(register, shared) -> register] edition +/** + * @brief Perform matrix outer product operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function computes an outer product of a register tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. 
+ * @tparam N_DIV_4 The height of the matrix `a` divided by 4. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source register tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_ABt(rt_fl &d, + const rt_bf &a, + const st_bf &b) { + KITTENS_CHECK_WARPGROUP + using base = kittens::wgmma::base; + #pragma unroll + for(int n = 0; n < N_DIV_4; n++) { + rt_fl<1, M, ducks::rt_layout::row> &d_ref = subtile_inplace<1>(d, n); + base::rt_st( + d_ref, + a.tiles[n][0], + base::b_desc(b, 0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.tiles[n][k], + base::b_desc(b, k), + 1 + ); + } + } +} +template +__device__ static inline void mm_ABt(rt_fl &d, + const rt_bf &a, + const st_bf &b) { + mma_ABt(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix outer product operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_A The layout of the matrix `a`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source shared tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_ABt(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf<4, K, L_A> &a, + const st_bf &b) { + KITTENS_CHECK_WARPGROUP + using base = kittens::wgmma::base; + base::st_st( + d, + base::a_desc(a, 0), + base::b_desc(b, 0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + base::a_desc(a, k), + base::b_desc(b, k), + 1 + ); + } +} +template +__device__ static inline void mm_ABt(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf<4, K, L_A> &a, + const st_bf &b) { + mma_ABt(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix multiply using warp group matrix multiply-accumulate (WGMMA) primitives, with A transposed. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_A The layout of the matrix `a`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source shared tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. 
+ */ +template +__device__ static inline void mma_AtB(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf &a, + const st_bf &b) { + KITTENS_CHECK_WARPGROUP + using base = kittens::wgmma::base; + base::st_st( + d, + base::a_desc(a, 0), + base::b_desc(b, 0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + base::a_desc(a, k), + base::b_desc(b, k), + 1 + ); + } +} +template +__device__ static inline void mm_AtB(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf &a, + const st_bf &b) { + mma_AtB(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix multiply using warp group matrix multiply-accumulate (WGMMA) primitives, with A and B transposed. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_A The layout of the matrix `a`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source shared tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_AtBt(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf &a, + const st_bf &b) { + KITTENS_CHECK_WARPGROUP + using base = kittens::wgmma::base; + base::st_st( + d, + base::a_desc(a, 0), + base::b_desc(b, 0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + base::a_desc(a, k), + base::b_desc(b, k), + 1 + ); + } +} +template +__device__ static inline void mm_AtBt(rt_fl<1, M, ducks::rt_layout::row> &d, + const st_bf &a, + const st_bf &b) { + mma_AtBt(d, a, b); +} + +/** + * @brief Synchronize the warp group and ensure that all writes to shared memory are visible to all threads in the warp group. + * + * This function acts as a fence for shared memory operations, ensuring that all previous writes are visible before proceeding. + * This function should be called before running wgmma::mma or wgmma::dot instructions. + * + * @tparam height The height of the matrix `dst`. + * @tparam width The width of the matrix `dst`. + * @param dst[in,out] The destination register-tile matrix to be synchronized. + */ +template +__device__ static inline void mma_fence(rt_fl &dst) { + KITTENS_CHECK_WARPGROUP + #pragma unroll + for(int i = 0; i < height; i++) { + #pragma unroll + for(int j = 0; j < width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + asm volatile("" : "+f"(dst.tiles[i][j].data[k].x) :: "memory"); + asm volatile("" : "+f"(dst.tiles[i][j].data[k].y) :: "memory"); + } + } + } + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +/** + * @brief Commit the current set of warp group matrix multiply accumulate calls. + */ + template // prevents static assert being instantiated unless called. +__device__ static inline void mma_commit_group() { + KITTENS_CHECK_WARPGROUP + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +/** + * @brief Wait for the warp group to reach a synchronization point. + * + * This function stalls the current warpgroup until enough WGMMA committed groups have been completed. + * + * @tparam N The number of remaining active WGMMA committed groups allowed. 
This will stall until the number of active groups is less than or equal to N. Defaults to 0. + */ +template +__device__ static inline void mma_async_wait() { + KITTENS_CHECK_WARPGROUP + asm volatile ("wgmma.wait_group.sync.aligned %0;" : : "n"(N) : "memory"); +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/ops.cuh b/triteia/csrc/flash_kittens/ops/ops.cuh new file mode 100644 index 0000000..db4f134 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/ops.cuh @@ -0,0 +1,9 @@ +/** + * @file + * @brief A collection of all of the operations that ThunderKittens defines. + */ + +#pragma once + +#include "warp/warp.cuh" +#include "group/group.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/memory.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/memory.cuh new file mode 100644 index 0000000..03dfe7f --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/memory.cuh @@ -0,0 +1,10 @@ +/** + * @file + * @brief An aggregate header of warp memory operations, where a single warp loads or stores data on its own. + */ + +#pragma once + +// #include "util/util.cuh" +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/tile/dsmem.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/tile/dsmem.cuh new file mode 100644 index 0000000..f0dee6f --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/tile/dsmem.cuh @@ -0,0 +1,31 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/shared/shared.cuh" +#include "../util/util.cuh" + +namespace kittens { +namespace dsmem { + +/** + * @brief Distributes data from a source shared tile to a destination shared tile across different thread blocks. + * + * This function wraps the distribute function by automatically calculating the number of bytes to be transferred + * based on the shared tile type and optional dimensions provided. It facilitates the distribution of data across + * different clusters or thread blocks in a device. + * + * @tparam ST The shared tile type. + * @tparam dims Variadic template parameter representing the dimensions of the array of shared tiles to be distributed. + * @param[in,out] dst_ Reference to the destination shared tile. + * @param[in,out] src_ Reference to the source shared tile. + * @param[in] cluster_size The size of the cluster or the number of thread blocks involved in the distribution. + * @param[in] dst_idx The index of the destination thread block within the cluster. + * @param[in,out] bar Reference to a barrier used for synchronization across thread blocks. + */ +template +__device__ static inline void distribute(ST &dst_, ST &src_, int cluster_size, int dst_idx, barrier& bar) { + distribute(dst_, src_, cluster_size, dst_idx, kittens::size_bytes, bar); // wrap with auto calculated bytes +} + +} +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_register.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_register.cuh new file mode 100644 index 0000000..05fae9d --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_register.cuh @@ -0,0 +1,182 @@ +/** + * @file + * @brief Functions for transferring data directly between global memory and registers and back. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/** + * @brief Load data from a source array into a row-major layout tile. 
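+ *
+ * Each lane fetches packed element pairs for its own half-warp and trades the
+ * other half through packed_shfl_sync, which keeps the global-memory accesses
+ * coalesced across the warp.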
+ * + * @tparam RT The row-major layout tile type. + * @tparam U The data type of the source array. + * @param dst[out] The destination tile to load data into. + * @param src[in] The source array to load data from. + * @param row_stride[in] The stride in elements between rows in the source array. + */ +template +__device__ inline static void load(RT &dst, const U *src, const int row_stride) { + using T2 = RT::dtype; + using U2 = base_types::packing::packed_type; + int laneid = kittens::laneid(); + int warphalf = (laneid & 16) > 0; + int warphalflaneid = laneid % 16; + #pragma unroll + for(int i = 0; i < dst.height; i++) { + int row_0to3 = i*dst.tile_size + (warphalflaneid / 4); + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + warphalf*8 + 2*(laneid % 4); + T2 transfers[2]; + transfers[0] = base_types::convertor::convert(*(U2*)(&src[(row_0to3+0)*row_stride + col])); + transfers[1] = base_types::convertor::convert(*(U2*)(&src[(row_0to3+4)*row_stride + col])); + transfers[1-warphalf] = packed_shfl_sync(MASK_ALL, transfers[1-warphalf], laneid^16); + dst.tiles[i][j].data[0] = transfers[0]; + dst.tiles[i][j].data[2] = transfers[1]; + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + warphalf*8 + 2*(laneid % 4); + T2 transfers[2]; + transfers[0] = base_types::convertor::convert(*(U2*)(&src[(row_0to3+ 8)*row_stride + col])); + transfers[1] = base_types::convertor::convert(*(U2*)(&src[(row_0to3+12)*row_stride + col])); + transfers[1-warphalf] = packed_shfl_sync(MASK_ALL, transfers[1-warphalf], laneid^16); + dst.tiles[i][j].data[1] = transfers[0]; + dst.tiles[i][j].data[3] = transfers[1]; + } + } +} +/** + * @brief Load data from a source array into a column-major layout tile. + * + * @tparam RT The column-major layout tile type. + * @tparam U The data type of the source array. + * @param dst[out] The destination tile to load data into. + * @param src[in] The source array to load data from. + * @param row_stride[in] The stride in elements between rows in the source array. 
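+ *
+ * @note Unlike the row-major path above, this layout cannot use packed 32-bit
+ * loads from global memory, so every element is fetched and converted
+ * individually.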
+ */ +template +__device__ inline static void load(RT &dst, const U *src, const int row_stride) { + using T = base_types::packing::unpacked_type; + int laneid = threadIdx.x % 32; + #pragma unroll + for(int i = 0; i < dst.height; i++) { + int row = i*dst.tile_size + 2*(laneid % 4); + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (laneid / 4); + dst.tiles[i][j].data[0].x = base_types::convertor::convert(src[(row+0)*row_stride + (col+0)]); + dst.tiles[i][j].data[1].x = base_types::convertor::convert(src[(row+0)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (laneid / 4); + dst.tiles[i][j].data[0].y = base_types::convertor::convert(src[(row+1)*row_stride + (col+0)]); + dst.tiles[i][j].data[1].y = base_types::convertor::convert(src[(row+1)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (laneid / 4); + dst.tiles[i][j].data[2].x = base_types::convertor::convert(src[(row+8)*row_stride + (col+0)]); + dst.tiles[i][j].data[3].x = base_types::convertor::convert(src[(row+8)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size + (laneid / 4); + dst.tiles[i][j].data[2].y = base_types::convertor::convert(src[(row+9)*row_stride + (col+0)]); + dst.tiles[i][j].data[3].y = base_types::convertor::convert(src[(row+9)*row_stride + (col+8)]); + } + } +} + +/** + * @brief Store data from a register tile to a destination array in global memory with a row-major layout. + * + * @tparam RT The register tile type with a row-major layout. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register tile to store data from. + * @param row_stride[in] The stride in elements between rows in the destination array. + */ +template +__device__ inline static void store(U *dst, const RT &src, const int row_stride) { + using T2 = RT::dtype; + using U2 = base_types::packing::packed_type; + int laneid = kittens::laneid(); + int warphalf = (laneid & 16) > 0; + int warphalflaneid = laneid % 16; + #pragma unroll + for(int i = 0; i < src.height; i++) { + int row_0to3 = i*src.tile_size + (warphalflaneid / 4); + int row = i*src.tile_size + (laneid / 4); + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + warphalf*8 + 2*(laneid % 4); + U2 transfers[2]; + transfers[0] = base_types::convertor::convert(src.tiles[i][j].data[0]); + transfers[1] = base_types::convertor::convert(src.tiles[i][j].data[2]); + transfers[1-warphalf] = packed_shfl_sync(MASK_ALL, transfers[1-warphalf], laneid^16); + *(U2*)(&dst[(row_0to3+0)*row_stride + col]) = transfers[0]; + *(U2*)(&dst[(row_0to3+4)*row_stride + col]) = transfers[1]; + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + warphalf*8 + 2*(laneid % 4); + U2 transfers[2]; + transfers[0] = base_types::convertor::convert(src.tiles[i][j].data[1]); + transfers[1] = base_types::convertor::convert(src.tiles[i][j].data[3]); + transfers[1-warphalf] = packed_shfl_sync(MASK_ALL, transfers[1-warphalf], laneid^16); + *(U2*)(&dst[(row_0to3+ 8)*row_stride + col]) = transfers[0]; + *(U2*)(&dst[(row_0to3+12)*row_stride + col]) = transfers[1]; + } + } +} +/** + * @brief Store data from a register tile to a destination array in global memory with a column-major layout. 
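+ *
+ * Mirrors the column-major load above: each lane writes the .x and .y halves of
+ * its packed registers to two adjacent rows, one element at a time.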
+ * + * @tparam RT The register tile type with a column-major layout. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register tile to store data from. + * @param row_stride[in] The stride in elements between rows in the destination array. + */ +template +__device__ inline static void store(U *dst, const RT &src, const int row_stride) { + using T = base_types::packing::unpacked_type; + int laneid = threadIdx.x % 32; + #pragma unroll + for(int i = 0; i < src.height; i++) { + int row = i*src.tile_size + 2*(laneid % 4); + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (laneid / 4); + dst[(row+0)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[0].x); + dst[(row+0)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[1].x); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (laneid / 4); + dst[(row+1)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[0].y); + dst[(row+1)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[1].y); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (laneid / 4); + dst[(row+8)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[2].x); + dst[(row+8)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[3].x); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size + (laneid / 4); + dst[(row+9)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[2].y); + dst[(row+9)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[3].y); + } + } +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_shared.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_shared.cuh new file mode 100644 index 0000000..98d064d --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/tile/global_to_shared.cuh @@ -0,0 +1,160 @@ +/** + * @file + * @brief Functions for transferring data directly between global and shared memory and back. + */ + +#pragma once + +#include + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +// ----------- ROW LAYOUTS ---------- + +/** + * @brief Loads bf16 data from global memory into a shared memory tile with a row layout. + * + * @tparam ST The type of the shared tile. + * @param[out] dst The destination shared memory tile. + * @param[in] src The source global memory array. + * @param row_stride[in] The stride between rows in the source array. 
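+ *
+ * @note Transfers are issued as 16-byte float4 copies, so `dst`, `src`, and
+ * `row_stride` are assumed to keep every access 16-byte aligned, matching the
+ * requirement documented on the async variants below.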
+ */ +template +__device__ static inline void load(ST &dst, const bf16 *src, const int row_stride) { + // each thread needs to do 1 call per width*height + // attempting to improve striping into dram + // each lane of the warp should store sequential into dram + + int laneid = threadIdx.x % 32; + + // we can handle this many rows each time we run a memcpy_async + int elem_per_memcpy = sizeof(float4)/sizeof(bf16); + int memcpy_per_row = dst.cols / elem_per_memcpy; + int total_calls = dst.height * dst.width; + + #pragma unroll + for(int i = 0; i < total_calls; i++) { + + int idx = i * 32 + laneid; + + int row = idx / memcpy_per_row; + int col = (idx*elem_per_memcpy) % dst.cols; + + *(float4*)(&dst[{row, col}]) = *(float4*)(&src[row*row_stride + col]); + } +} +/** + * @brief Stores bf16 data from a shared memory tile with a row layout into global memory. + * + * @tparam ST The type of the shared tile. + * @param[out] dst The destination global memory array. + * @param[in] src The source shared memory tile. + * @param row_stride[in] The stride between rows in the destination array. + */ +template +__device__ static inline void store(bf16 *dst, const ST &src, const int row_stride) { + + int laneid = threadIdx.x % 32; + + // we can handle this many rows each time we run a memcpy_async + int elem_per_memcpy = sizeof(float4)/sizeof(bf16); + int memcpy_per_row = src.cols / elem_per_memcpy; + int total_calls = src.height * src.width; + + #pragma unroll + for(int i = 0; i < total_calls; i++) { + + int idx = i * 32 + laneid; + + int row = idx / memcpy_per_row; + int col = (idx*elem_per_memcpy) % src.cols; + + *(float4*)(&dst[row*row_stride + col]) = *(float4*)(&src[{row, col}]); + } +} + +/** + * @brief Asynchronously loads bf16 data from global memory into a shared memory tile with a row layout using CUDA barriers. + * + * @tparam ST The type of the shared tile. + * @param[out] dst The destination shared memory tile. + * @param[in] src The source global memory array. + * @param row_stride[in] The stride between rows in the source array. + * @param barrier[in,out] The CUDA barrier used for synchronization. + * + * @note This function expects 16-byte alignments. Otherwise, behavior is undefined. + */ +template +__device__ static inline void load_async(ST &dst, const bf16 *src, const int row_stride, cuda::barrier &barrier) { + // each thread needs to do 1 call per width*height + // attempting to improve striping into dram + // each lane of the warp should store sequential into dram + + int laneid = threadIdx.x % 32; + + // we can handle this many rows each time we run a memcpy_async + int elem_per_memcpy = sizeof(float4)/sizeof(bf16); + int memcpy_per_row = dst.cols / elem_per_memcpy; + int total_calls = dst.height * dst.width; + + #pragma unroll + for(int i = 0; i < total_calls; i++) { + + int idx = i * 32 + laneid; + + int row = idx / memcpy_per_row; + int col = (idx*elem_per_memcpy) % dst.cols; + + cuda::memcpy_async( + (void*)(&dst[{row, col}]), + (void*)(&src[row*row_stride + col]), + cuda::aligned_size_t<16>(sizeof(float4)), + barrier + ); + } +} +/** + * @brief Asynchronously stores bf16 data from a shared memory tile with a row layout into global memory using CUDA barriers. + * + * @tparam ST The type of the shared tile + * @param[out] dst The destination global memory array. + * @param[in] src The source shared memory tile. + * @param row_stride[in] The stride between rows in the destination array. + * @param barrier[in,out] The CUDA barrier used for synchronization. 
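+ *
+ * Example (an illustrative sketch; the tile/barrier names and setup are assumptions):
+ * @code
+ * __shared__ cuda::barrier<cuda::thread_scope_block> bar;
+ * if (threadIdx.x == 0) init(&bar, blockDim.x);
+ * __syncthreads();
+ * store_async(dst, tile, row_stride, bar);
+ * bar.arrive_and_wait(); // returns once every memcpy_async has completed
+ * @endcode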
+ * + * @note This function expects 16-byte alignments. Otherwise, behavior is undefined. + */ +template +__device__ static inline void store_async(bf16 *dst, const ST &src, const int row_stride, cuda::barrier &barrier) { + // each thread needs to do 1 call per width*height + // attempting to improve striping into dram + // each lane of the warp should store sequential into dram + + int laneid = threadIdx.x % 32; + + // we can handle this many rows each time we run a memcpy_async + int elem_per_memcpy = sizeof(float4)/sizeof(bf16); + int memcpy_per_row = src.cols / elem_per_memcpy; + int total_calls = src.height * src.width; + + #pragma unroll + for(int i = 0; i < total_calls; i++) { + + int idx = i * 32 + laneid; + + int row = idx / memcpy_per_row; + int col = (idx*elem_per_memcpy) % src.cols; + + cuda::memcpy_async( + (void*)(&dst[row*row_stride + col]), + (void*)(&src[{row, col}]), + cuda::aligned_size_t<16>(sizeof(float4)), + barrier + ); + } +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/tile/shared_to_register.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/tile/shared_to_register.cuh new file mode 100644 index 0000000..5f2d5cc --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/tile/shared_to_register.cuh @@ -0,0 +1,120 @@ +/** + * @file + * @brief Functions for transferring data directly between shared memory and registers and back. + */ + +#pragma once + +#include + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +// These probably need to be redone to reduce bank conflicts. +// They currently work fine with xor layout but it should be +// possible to reduce their bank conflicts with other layouts too. + +/** + * @brief Load data from a shared tile into a register tile. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination register tile. + * @param src[in] The source shared tile. 
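+ *
+ * Example (a minimal sketch; the 16x64 tile shape is an illustrative assumption):
+ * @code
+ * __shared__ st_bf<1, 4> s;             // 16x64 bf16 shared tile
+ * rt_bf<1, 4, ducks::rt_layout::row> r; // register tile with matching dimensions
+ * load(r, s); // height/width equality is enforced by static_assert
+ * @endcode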
+ */ +template +__device__ inline static void load(RT &dst, const ST &src) { + + static_assert(RT::height == ST::height, "register tile and shared tile must match height"); + static_assert(RT::width == ST::width, "register tile and shared tile must match width"); + + using T2 = RT::dtype; + using T = base_types::packing::unpacked_type; + using U = ST::dtype; + using U2 = base_types::packing::packed_type; + + int laneid = threadIdx.x % 32; + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if constexpr (std::is_same_v) { + // handle the row-major layout + int row = i*dst.tile_size + (laneid / 4); + int col = j*dst.tile_size + 2*(laneid % 4); + dst.tiles[i][j].data[0] = base_types::convertor::convert(*(U2*)(&src[{row+0, col+0}])); + dst.tiles[i][j].data[1] = base_types::convertor::convert(*(U2*)(&src[{row+8, col+0}])); + dst.tiles[i][j].data[2] = base_types::convertor::convert(*(U2*)(&src[{row+0, col+8}])); + dst.tiles[i][j].data[3] = base_types::convertor::convert(*(U2*)(&src[{row+8, col+8}])); + } + else { + // handle the column-major layout + int row = i*dst.tile_size + 2*(laneid % 4); + int col = j*dst.tile_size + (laneid / 4); + dst.tiles[i][j].data[0].x = base_types::convertor::convert(src[{row+0, col+0}]); + dst.tiles[i][j].data[0].y = base_types::convertor::convert(src[{row+1, col+0}]); + dst.tiles[i][j].data[1].x = base_types::convertor::convert(src[{row+0, col+8}]); + dst.tiles[i][j].data[1].y = base_types::convertor::convert(src[{row+1, col+8}]); + dst.tiles[i][j].data[2].x = base_types::convertor::convert(src[{row+8, col+0}]); + dst.tiles[i][j].data[2].y = base_types::convertor::convert(src[{row+9, col+0}]); + dst.tiles[i][j].data[3].x = base_types::convertor::convert(src[{row+8, col+8}]); + dst.tiles[i][j].data[3].y = base_types::convertor::convert(src[{row+9, col+8}]); + } + } + } +} + + +/** + * @brief Store data into a shared tile from a register tile. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination shared tile. + * @param src[in] The source register tile. 
+ */ +template +__device__ inline static void store(ST &dst, const RT &src) { + + static_assert(RT::height == ST::height, "register tile and shared tile must match height"); + static_assert(RT::width == ST::width, "register tile and shared tile must match width"); + + using T2 = RT::dtype; + using T = base_types::packing::unpacked_type; + using U = ST::dtype; + using U2 = base_types::packing::packed_type; + + int laneid = threadIdx.x % 32; + #pragma unroll + for(int i = 0; i < src.height; i++) { + #pragma unroll + for(int j = 0; j < src.width; j++) { + if constexpr (std::is_same_v) { + // handle the row-major layout + int row = i*src.tile_size + (laneid / 4); + int col = j*src.tile_size + 2*(laneid % 4); + *(U2*)(&dst[{row+0, col+0}]) = base_types::convertor::convert(src.tiles[i][j].data[0]); + *(U2*)(&dst[{row+8, col+0}]) = base_types::convertor::convert(src.tiles[i][j].data[1]); + *(U2*)(&dst[{row+0, col+8}]) = base_types::convertor::convert(src.tiles[i][j].data[2]); + *(U2*)(&dst[{row+8, col+8}]) = base_types::convertor::convert(src.tiles[i][j].data[3]); + } + else { + // handle the column-major layout + int row = i*src.tile_size + 2*(laneid % 4); + int col = j*src.tile_size + (laneid / 4); + dst[{row+0, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[0].x); + dst[{row+1, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[0].y); + dst[{row+0, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[1].x); + dst[{row+1, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[1].y); + dst[{row+8, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[2].x); + dst[{row+9, col+0}] = base_types::convertor::convert(src.tiles[i][j].data[2].y); + dst[{row+8, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[3].x); + dst[{row+9, col+8}] = base_types::convertor::convert(src.tiles[i][j].data[3].y); + } + } + } +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/tile/tile.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/tile/tile.cuh new file mode 100644 index 0000000..48d1bc1 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/tile/tile.cuh @@ -0,0 +1,15 @@ +/** + * @file + * @brief An aggregate header of warp memory operations on tiles, where a single warp loads or stores data on its own. + */ + +#pragma once + +#include "shared_to_register.cuh" +#include "global_to_register.cuh" +#include "global_to_shared.cuh" + +#ifdef KITTENS_HOPPER +#include "tma.cuh" +#include "dsmem.cuh" +#endif \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/tile/tma.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/tile/tma.cuh new file mode 100644 index 0000000..010f350 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/tile/tma.cuh @@ -0,0 +1,653 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" +#include "../util/util.cuh" + +#include +#include + +namespace kittens { +namespace tma { + +/* ---------- Create tensor map descriptor (HOST) ---------- */ + +/** +* @brief Creates a tensor map for the given source tensor. +* +* This function creates a tensor map (CUtensorMap) for the specified source shared tile type. The tensor map +* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor +* map based on the provided source tensor pointer and the layout specified by the ST template parameter. +* +* @tparam ST The source tensor type, which must be TMA-compatible. 
+* @tparam blocks_height The number of tiles present on the height axis in global memory. +* @tparam blocks_width The number of tiles present on the width axis in global memory. Defaults to 1. +* @param tma_map Pointer to the CUtensorMap object to be initialized. +* @param src Pointer to the source tensor data in global memory. +*/ +template +__host__ static inline void create_tensor_map(CUtensorMap *tma_map, const bf16 *src, int blocks_height, int blocks_width=1) { + static_assert(std::is_same_v); + + constexpr uint32_t tma_dim = ( + detail::st_type_naive_layout ? 2 : + detail::st_type_swizzle_layout ? 3 : + detail::st_type_wgmma_swizzle_layout ? 3 : + detail::st_type_wgmma_interleave_layout ? 4 : + -1 + ); + void *global_addr = (void*)(src); + + constexpr CUtensorMapDataType tma_format = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + constexpr CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + constexpr CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + constexpr CUtensorMapSwizzle tma_swizzle = ( + ST::swizzle_bytes == 32 ? CU_TENSOR_MAP_SWIZZLE_32B : + ST::swizzle_bytes == 64 ? CU_TENSOR_MAP_SWIZZLE_64B : + ST::swizzle_bytes == 128 ? CU_TENSOR_MAP_SWIZZLE_128B : + CU_TENSOR_MAP_SWIZZLE_NONE + ); + + uint64_t gmem_shape [4] = {0, 0, 0, 0}; + uint64_t gmem_stride[3] = {0, 0, 0}; + uint32_t smem_shape [4] = {0, 0, 0, 0}; + uint32_t smem_stride[4] = {1, 1, 1, 1}; + + uint64_t global_tile_height = blocks_height * ST::rows; + uint64_t global_tile_width = blocks_width * ST::cols; + constexpr uint64_t shared_tile_height = ST::rows; + constexpr uint64_t shared_tile_width = ST::cols; + + if constexpr (detail::st_type_naive_layout) { + gmem_shape[0] = global_tile_width; + gmem_shape[1] = global_tile_height; + + gmem_stride[0] = global_tile_width * sizeof(bf16); + + smem_shape[0] = shared_tile_width; + smem_shape[1] = shared_tile_height; + } + else if constexpr (detail::st_type_swizzle_layout) { + constexpr int swizzle_elements = ST::swizzle_bytes / sizeof(bf16); + + gmem_shape[0] = swizzle_elements; + gmem_shape[1] = global_tile_width / swizzle_elements; + gmem_shape[2] = global_tile_height; + + gmem_stride[0] = ST::swizzle_bytes; + gmem_stride[1] = global_tile_width * sizeof(bf16); + + smem_shape[0] = swizzle_elements; + smem_shape[1] = shared_tile_width / swizzle_elements; + smem_shape[2] = shared_tile_height; + } + else if constexpr (detail::st_type_wgmma_swizzle_layout) { + constexpr int swizzle_elements = ST::swizzle_bytes / sizeof(bf16); + + gmem_shape[0] = swizzle_elements; + gmem_shape[1] = global_tile_height; + gmem_shape[2] = global_tile_width / swizzle_elements; + + gmem_stride[0] = global_tile_width * sizeof(bf16); + gmem_stride[1] = ST::swizzle_bytes; + + smem_shape[0] = swizzle_elements; + smem_shape[1] = shared_tile_height; + smem_shape[2] = shared_tile_width / swizzle_elements; + } + else if constexpr (detail::st_type_wgmma_interleave_layout) { + gmem_shape[0] = 8; + gmem_shape[1] = 8; + gmem_shape[2] = global_tile_width/8; + gmem_shape[3] = global_tile_height/8; + + gmem_stride[0] = global_tile_width * sizeof(bf16); + gmem_stride[1] = 8 * sizeof(bf16); + gmem_stride[2] = 8 * global_tile_width * sizeof(bf16); + + smem_shape[0] = 8; + smem_shape[1] = 8; + smem_shape[2] = shared_tile_width/8; + smem_shape[3] = shared_tile_height/8; + } + + // ensure that the global address is always 16-byte aligned + assert((reinterpret_cast(global_addr) & 0b1111) == 0); + + 
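+    // The checks below mirror cuTensorMapEncodeTiled's documented requirements
+    // (strides must be multiples of 16 bytes, box dimensions at most 256 elements,
+    // element strides at most 8), so violations fail fast in debug builds instead
+    // of surfacing as CUDA_ERROR_INVALID_VALUE at encode time.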
assert(gmem_stride[0] % 16 == 0); // gmem_stride[0] must be a multiple of 16 bytes
+    assert(gmem_stride[1] % 16 == 0); // gmem_stride[1] must be a multiple of 16 bytes
+    assert(gmem_stride[2] % 16 == 0); // gmem_stride[2] must be a multiple of 16 bytes
+
+    assert(smem_shape[0] <= 256); // smem_shape[0] must be <= 256 elements
+    assert(smem_shape[1] <= 256); // smem_shape[1] must be <= 256 elements
+    assert(smem_shape[2] <= 256); // smem_shape[2] must be <= 256 elements
+    assert(smem_shape[3] <= 256); // smem_shape[3] must be <= 256 elements
+
+    assert(smem_shape[0] * sizeof(bf16) % 16 == 0); // if wgmma_interleave is none, smem_shape[0] * sizeof(bf16) must be a multiple of 16 bytes
+
+    assert(smem_stride[0] <= 8); // smem_stride[0] must be <= 8
+    assert(smem_stride[1] <= 8); // smem_stride[1] must be <= 8
+    assert(smem_stride[2] <= 8); // smem_stride[2] must be <= 8
+    assert(smem_stride[3] <= 8); // smem_stride[3] must be <= 8
+
+    assert(smem_stride[0] == 1); // smem_stride[0] is ignored when wgmma_interleave is none
+
+    if constexpr (tma_interleave == CU_TENSOR_MAP_INTERLEAVE_NONE && tma_swizzle != CU_TENSOR_MAP_SWIZZLE_NONE) {
+        constexpr int swizzle_size = (ST::width) * 32;
+        assert(smem_shape[0] * sizeof(bf16) <= swizzle_size);
+    }
+
+    const uint64_t *gmem_shape_ptr = &gmem_shape[0];
+    const uint64_t *gmem_stride_ptr = &gmem_stride[0];
+    const uint32_t *smem_shape_ptr = &smem_shape[0];
+    const uint32_t *smem_stride_ptr = &smem_stride[0];
+
+    CUresult result = cuTensorMapEncodeTiled(
+        tma_map,
+        tma_format,
+        tma_dim,
+        global_addr,
+        gmem_shape_ptr,
+        gmem_stride_ptr,
+        smem_shape_ptr,
+        smem_stride_ptr,
+        tma_interleave,
+        tma_swizzle,
+        tma_l2Promotion,
+        tma_oobFill);
+
+    const char *error_string;
+    cuGetErrorString(result, &error_string);
+    if (result != CUDA_SUCCESS) {
+        std::cerr << "Error: " << error_string << std::endl;
+    }
+}
+
+/**
+* @brief Allocates on the GPU and initializes a tensor map for the given source tensor.
+*
+* This function creates a tensor map (CUtensorMap) for the specified source shared tile type. The tensor map
+* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor
+* map based on the provided source tensor pointer and the layout specified by the ST template parameter.
+*
+* @tparam ST The source tensor type, which must be TMA-compatible.
+* @tparam blocks_height The number of tiles present on the height axis in global memory.
+* @tparam blocks_width The number of tiles present on the width axis in global memory. Defaults to 1.
+* @param src Pointer to the source tensor data in global memory.
+* @returns Device pointer to the initialized CUtensorMap.
+*/
+template
+__host__ static inline CUtensorMap* allocate_and_create_tensor_map(const bf16 *src, int blocks_height, int blocks_width=1) {
+    CUtensorMap *tma_map_d;
+    cudaMalloc(&tma_map_d, sizeof(CUtensorMap));
+    CUtensorMap tma_map_host; // put it on the stack, why not.
+    create_tensor_map(&tma_map_host, src, blocks_height, blocks_width);
+    cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
+    return tma_map_d;
+}
+
+/* ---------- Prefetch Tensor Map ---------- */
+
+/**
+ * @brief Prefetches data from global memory into a shared memory tile, along with the tensormap.
+ *
+ * @tparam ST A shared tile type with a TMA-compatible layout
+ * @param[out] dst The destination shared memory tile.
+ * @param[in] src_tma_map The source tensormap address in global memory
+ * @param[in] tile_row_idx The row index of the requested tile. This is in units of complete tiles.
+ * @param[in] tile_col_idx The column index of the requested tile. This is in units of complete tiles.
+ */
+template
+__device__ static inline void prefetch(ST &dst, void const* const src_tma_map, int tile_row_idx, int tile_col_idx=0) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(src_tma_map);
+
+        if constexpr (detail::st_type_naive_layout<ST>) {
+            int32_t crd0 = tile_col_idx * (dst.cols);
+            int32_t crd1 = tile_row_idx * (dst.rows);
+
+            asm volatile (
+                "cp.async.bulk.prefetch.tensor.2d.L2.global.tile"
+                " [%0, {%1, %2}];"
+                :
+                : "l"(tma_ptr),
+                  "r"(crd0), "r"(crd1)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_col_idx * (dst.cols / (ST::swizzle_bytes / sizeof(bf16)));
+            int32_t crd2 = tile_row_idx * (dst.rows);
+
+            asm volatile (
+                "cp.async.bulk.prefetch.tensor.3d.L2.global.tile"
+                " [%0, {%1, %2, %3}];"
+                :
+                : "l"(tma_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_row_idx * (dst.rows);
+            int32_t crd2 = tile_col_idx * (dst.cols / (ST::swizzle_bytes / sizeof(bf16)));
+
+            asm volatile (
+                "cp.async.bulk.prefetch.tensor.3d.L2.global.tile"
+                " [%0, {%1, %2, %3}];"
+                :
+                : "l"(tma_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_interleave_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = 0;
+            int32_t crd2 = tile_col_idx * (dst.cols/8);
+            int32_t crd3 = tile_row_idx * (dst.rows/8);
+
+            asm volatile (
+                "cp.async.bulk.prefetch.tensor.4d.L2.global.tile"
+                " [%0, {%1, %2, %3, %4}];"
+                :
+                : "l"(tma_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
+                : "memory"
+            );
+        }
+    }
+}
+
+/* ---------- Async load and store data from gmem/smem ---------- */
+
+/**
+ * @brief Asynchronously stores data into global memory from a shared memory tile.
+ *
+ * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
+ *
+ * @tparam ST A shared tile type with a TMA-compatible layout
+ * @param[out] dst_tma_map The destination tensormap address in global memory
+ * @param[in] src The source shared memory tile.
+ * @param[in] tile_row_idx The row index of the tile destination. This is in units of complete tiles.
+ * @param[in] tile_col_idx The column index of the tile destination. This is in units of complete tiles.
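+ *
+ * Example (illustrative; store_commit_group and store_async_wait are defined in
+ * util/tma.cuh):
+ * @code
+ * store_async(dst_tma_map, tile, 0, 0); // enqueue tile (row 0, col 0)
+ * store_commit_group();                 // commit the bulk-store group
+ * store_async_wait<0>();                // stall until it has drained
+ * @endcode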
+ */
+template
+__device__ static inline void store_async(void *dst_tma_map, const ST &src, int tile_row_idx, int tile_col_idx=0) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst_tma_map);
+        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
+
+        if constexpr (detail::st_type_naive_layout<ST>) {
+            int32_t crd0 = tile_col_idx * (src.cols);
+            int32_t crd1 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group"
+                " [%0, {%2, %3}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+            int32_t crd2 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.async.bulk.tensor.3d.global.shared::cta.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_row_idx * (src.rows);
+            int32_t crd2 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+
+            asm volatile (
+                "cp.async.bulk.tensor.3d.global.shared::cta.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_interleave_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = 0;
+            int32_t crd2 = tile_col_idx * (src.cols/8);
+            int32_t crd3 = tile_row_idx * (src.rows/8);
+
+            asm volatile (
+                "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group"
+                " [%0, {%2, %3, %4, %5}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
+                : "memory"
+            );
+        }
+    }
+}
+
+/* ---------- Async reduction + store data from gmem/smem ---------- */
+
+/**
+ * @brief Asynchronously performs an add reduction and stores the result into global memory from a shared memory tile.
+ *
+ * This function performs an asynchronous add reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction.
+ *
+ * @tparam ST A shared tile type with a TMA-compatible layout
+ * @param[out] dst_tma_map The destination tensormap address in global memory
+ * @param[in] src The source shared memory tile.
+ * @param[in] tile_row_idx The row index of the tile destination. This is in units of complete tiles.
+ * @param[in] tile_col_idx The column index of the tile destination. This is in units of complete tiles.
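+ *
+ * @note Only lane 0 issues the reduction, and completion is tracked through the
+ * same store_commit_group()/store_async_wait() pair as a plain store_async.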
+ */
+template
+__device__ static inline void store_add_async(void *dst_tma_map, const ST &src, int tile_row_idx, int tile_col_idx=0) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst_tma_map);
+        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
+
+        if constexpr (detail::st_type_naive_layout<ST>) {
+            int32_t crd0 = tile_col_idx * (src.cols);
+            int32_t crd1 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group"
+                " [%0, {%2, %3}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+            int32_t crd2 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_row_idx * (src.rows);
+            int32_t crd2 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_interleave_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = 0;
+            int32_t crd2 = tile_col_idx * (src.cols/8);
+            int32_t crd3 = tile_row_idx * (src.rows/8);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.tile.bulk_group"
+                " [%0, {%2, %3, %4, %5}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
+                : "memory"
+            );
+        }
+    }
+}
+
+/**
+ * @brief Asynchronously performs a min reduction and stores the result into global memory from a shared memory tile.
+ *
+ * This function performs an asynchronous min reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction.
+ *
+ * @tparam ST A shared tile type with a TMA-compatible layout
+ * @param[out] dst_tma_map The destination tensormap address in global memory
+ * @param[in] src The source shared memory tile.
+ * @param[in] tile_row_idx The row index of the tile destination. This is in units of complete tiles.
+ * @param[in] tile_col_idx The column index of the tile destination. This is in units of complete tiles.
+ */
+template
+__device__ static inline void store_min_async(void *dst_tma_map, const ST &src, int tile_row_idx, int tile_col_idx=0) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst_tma_map);
+        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
+
+        if constexpr (detail::st_type_naive_layout<ST>) {
+            int32_t crd0 = tile_col_idx * (src.cols);
+            int32_t crd1 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.2d.global.shared::cta.min.tile.bulk_group"
+                " [%0, {%2, %3}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+            int32_t crd2 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.3d.global.shared::cta.min.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_row_idx * (src.rows);
+            int32_t crd2 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.3d.global.shared::cta.min.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_interleave_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = 0;
+            int32_t crd2 = tile_col_idx * (src.cols/8);
+            int32_t crd3 = tile_row_idx * (src.rows/8);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.4d.global.shared::cta.min.tile.bulk_group"
+                " [%0, {%2, %3, %4, %5}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
+                : "memory"
+            );
+        }
+    }
+}
+
+/**
+ * @brief Asynchronously performs a max reduction and stores the result into global memory from a shared memory tile.
+ *
+ * This function performs an asynchronous max reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction.
+ *
+ * @tparam ST A shared tile type with a TMA-compatible layout
+ * @param[out] dst_tma_map The destination tensormap address in global memory
+ * @param[in] src The source shared memory tile.
+ * @param[in] tile_row_idx The row index of the tile destination. This is in units of complete tiles.
+ * @param[in] tile_col_idx The column index of the tile destination. This is in units of complete tiles.
+ */
+template
+__device__ static inline void store_max_async(void *dst_tma_map, const ST &src, int tile_row_idx, int tile_col_idx=0) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst_tma_map);
+        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
+
+        if constexpr (detail::st_type_naive_layout<ST>) {
+            int32_t crd0 = tile_col_idx * (src.cols);
+            int32_t crd1 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.2d.global.shared::cta.max.tile.bulk_group"
+                " [%0, {%2, %3}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+            int32_t crd2 = tile_row_idx * (src.rows);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.3d.global.shared::cta.max.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_row_idx * (src.rows);
+            int32_t crd2 = tile_col_idx * (src.cols / (ST::swizzle_bytes / sizeof(bf16)));
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.3d.global.shared::cta.max.tile.bulk_group"
+                " [%0, {%2, %3, %4}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_interleave_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = 0;
+            int32_t crd2 = tile_col_idx * (src.cols/8);
+            int32_t crd3 = tile_row_idx * (src.rows/8);
+
+            asm volatile (
+                "cp.reduce.async.bulk.tensor.4d.global.shared::cta.max.tile.bulk_group"
+                " [%0, {%2, %3, %4, %5}], [%1];"
+                :
+                : "l"(tma_ptr), "r"(src_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
+                : "memory"
+            );
+        }
+    }
+}
+
+/**
+ * @brief Asynchronously loads data from global memory into a shared memory tile.
+ *
+ * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
+ *
+ * @tparam ST A shared tile type with a TMA-compatible layout
+ * @param[out] dst The destination shared memory tile.
+ * @param[in] src_tma_map The source tensormap address in global memory
+ * @param[in,out] bar The barrier used for synchronization of the asynchronous copy.
+ * @param[in] tile_row_idx The row index of the requested tile. This is in units of complete tiles.
+ * @param[in] tile_col_idx The column index of the requested tile. This is in units of complete tiles.
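+ *
+ * Example (an illustrative sketch; the barrier helpers are defined in util/tma.cuh,
+ * and the phase bit of 0 assumes this is the barrier's first use):
+ * @code
+ * __shared__ barrier bar;
+ * init_barrier<decltype(tile)>(bar); // also sets the expected byte count
+ * load_async(tile, src_tma_map, bar, 0, 0);
+ * arrive_and_wait(bar, 0);           // blocks until the TMA transfer lands
+ * @endcode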
+ */
+template
+__device__ static inline void load_async(ST &dst, void const* const src_tma_map, barrier& bar, int tile_row_idx, int tile_col_idx=0) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(src_tma_map);
+        uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
+        uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst));
+
+        if constexpr (detail::st_type_naive_layout<ST>) {
+            int32_t crd0 = tile_col_idx * (dst.cols);
+            int32_t crd1 = tile_row_idx * (dst.rows);
+
+            asm volatile (
+                "cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
+                " [%0], [%1, {%3, %4}], [%2];"
+                :
+                : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
+                  "r"(crd0), "r"(crd1)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_col_idx * (dst.cols / (ST::swizzle_bytes / sizeof(bf16)));
+            int32_t crd2 = tile_row_idx * (dst.rows);
+
+            asm volatile (
+                "cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
+                " [%0], [%1, {%3, %4, %5}], [%2];"
+                :
+                : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_swizzle_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = tile_row_idx * (dst.rows);
+            int32_t crd2 = tile_col_idx * (dst.cols / (ST::swizzle_bytes / sizeof(bf16)));
+
+            asm volatile (
+                "cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
+                " [%0], [%1, {%3, %4, %5}], [%2];"
+                :
+                : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2)
+                : "memory"
+            );
+        }
+        else if constexpr (detail::st_type_wgmma_interleave_layout<ST>) {
+            int32_t crd0 = 0;
+            int32_t crd1 = 0;
+            int32_t crd2 = tile_col_idx * (dst.cols/8);
+            int32_t crd3 = tile_row_idx * (dst.rows/8);
+
+            asm volatile (
+                "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
+                " [%0], [%1, {%3, %4, %5, %6}], [%2];"
+                :
+                : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
+                  "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
+                : "memory"
+            );
+        }
+    }
+}
+
+} // namespace tma
+} // namespace kittens
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/util/dsmem.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/util/dsmem.cuh
new file mode 100644
index 0000000..22bc5a5
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/memory/util/dsmem.cuh
@@ -0,0 +1,136 @@
+#pragma once
+
+#include "../../../../common/common.cuh"
+#include "../../../../types/shared/shared.cuh"
+
+namespace kittens {
+namespace dsmem {
+
+using barrier = uint64_t;
+
+/**
+ * @brief Waits at a dsmem barrier until the memory and sufficient threads have arrived.
+ *
+ * This function is used to synchronize threads at a barrier. Each thread waits at the barrier
+ * until the local memory has arrived.
+ *
+ * @param bar Reference to the barrier variable.
+ * @param kPhaseBit The phase bit used for the barrier.
+ */
+__device__ static inline void arrive_and_wait(barrier& bar, int kPhaseBit) {
+    void const* const ptr = &bar;
+    uint32_t mbarrier_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
+
+    asm volatile (
+        "{\n"
+        ".reg .pred P1;\n"
+        "LAB_WAIT:\n"
+        "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
+        "@P1 bra.uni DONE;\n"
+        "bra.uni LAB_WAIT;\n"
+        "DONE:\n"
+        "}\n"
+        ::
+        "r"(mbarrier_ptr),
+        "r"(kPhaseBit)
+    );
+}
+
+/**
+ * @brief Sets the number of bytes expected at the barrier.
+ * + * This function is called by the first thread in the warp (laneid() == 0) to set the number of bytes + * expected at the barrier. It converts the barrier pointer to a generic shared memory pointer and + * uses inline assembly to set the expected number of bytes. + * + * @param bar Reference to the barrier variable. + * @param bytes The number of bytes expected at the barrier. + */ +__device__ static inline void set_bytes(barrier& bar, uint32_t bytes) { + if (laneid() == 0) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(bar_ptr), "r"(bytes) + ); + + } +} + + +/** + * @brief Initialize a distribute shared memory barrier + * + * If the template arguments are left blank, the user is expected to call set_bytes manually. + * Alternatively, if a shared tile or shared vector type is passed, along with optional array + * dimensions, the barrier will be automatically initialized with the correct transaction size, too. + * + * @tparam T the type of the shared memory object being passed. Defaults to kittens::ducks::default_type. + * @tparam dims... Dimensions of the multidimensional array, if an array is being transferred. If blank, a single object is transferred. + * @param[out] bar Reference to the barrier variable. + * @param[in] tc The number of arriving threads the barrier should also wait for. + */ +template +__device__ static inline void init_barrier(barrier& bar, int tc=1) { + if (laneid() == 0) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(tc) + ); + } + // Now initialize the bar bytes + if constexpr (ducks::st::all || ducks::sv::all) { + set_bytes(bar, kittens::size_bytes); + } +} + +// Generic transfer +template +__device__ static inline void distribute(T &dst_, T &src_, int cluster_size, int dst_idx, uint32_t size_bytes, barrier& bar) { + if (laneid() == 0) { + void const* const ptr = &bar; + uint32_t mbarrier_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + // ************************************************** + // load from src to dst in different threadblocks + auto src = &src_; + auto dst = &dst_; + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + uint32_t neighbor_rank = dst_idx; + + // mapa instr = https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mapa + // find dst addr in neighbor's cta + uint32_t neighbor_addr_dst = dst_ptr; + asm volatile ( + "mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_dst) + : "r"(dst_ptr), "r"(neighbor_rank) + ); + + uint32_t neighbor_addr_mbarrier = mbarrier_ptr; + asm volatile ( + "mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_mbarrier) + : "r"(mbarrier_ptr), "r"(neighbor_rank) + ); + + // cp.async instr = https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk + // copy src into dst in neighbor's cta + asm volatile ( + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(neighbor_addr_dst), "r"(src_ptr), "r"(size_bytes), "r"(neighbor_addr_mbarrier) + : "memory" + ); + } +} + +} +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/util/tma.cuh 
b/triteia/csrc/flash_kittens/ops/warp/memory/util/tma.cuh new file mode 100644 index 0000000..6f19f34 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/util/tma.cuh @@ -0,0 +1,139 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +#include +#include + +namespace kittens { +/** + * @brief A namespace for all of ThunderKittens' TMA functionality. +*/ +namespace tma { + +namespace detail { + +// Concepts for tiles +template concept st_type_naive_layout = ( + std::is_same_v +); +template concept st_type_swizzle_layout = ( + std::is_same_v +); +template concept st_type_wgmma_swizzle_layout = ( + std::is_same_v +); +template concept st_type_wgmma_interleave_layout = ( + std::is_same_v +); + +} + +using barrier = uint64_t; + +/* ---------- Barrier functions for async load ---------- */ + +/** +* @brief Sets the number of bytes expected at the barrier. +* +* This function sets the number of bytes expected at the barrier for the first thread in the warp. +* It converts the barrier pointer to a generic shared memory pointer and uses an inline assembly +* instruction to set the expected number of bytes. +* +* @param barrier Reference to the barrier variable. +* @param bytes The number of bytes expected at the barrier. +*/ +__device__ static inline void set_bytes(barrier& bar, uint32_t bytes) { + if (::kittens::laneid() == 0) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(bar_ptr), "r"(bytes)); + } +} +/** + * @brief Initializes a synchronization barrier with a transaction count and sets the expected number of bytes. + * + * This function sets up a barrier that is used to synchronize threads within a block during asynchronous operations. + * It initializes the barrier with a thread count barrier. + * + * Additionally, if it is given a shared tile type, it will also call `set_bytes` to prepare for the memory transaction. + * + * @param[out] barrier The barrier variable to initialize. + * @param[in] tc The thread counter for the barrier. + */ +template +__device__ static inline void init_barrier(barrier& bar, int tc=1) { + static_assert(ducks::st::all || ducks::sv::all || std::is_same_v); + if (::kittens::laneid() == 0) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(tc)); + + if constexpr (ducks::st::all || ducks::sv::all) { + set_bytes(bar, kittens::size_bytes); // set barrier bytes automatically + } + } +} + +/** +* @brief Arrives at the barrier and waits for all threads to arrive. +* +* This function is used to synchronize threads at a barrier. Each thread arrives at the barrier +* and waits until all threads have arrived. The function uses inline assembly to perform the +* barrier wait operation. +* +* @param barrier Reference to the barrier variable. +* @param kPhaseBit The phase bit used for the barrier. 
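+*
+* @note The phase bit must alternate 0, 1, 0, 1, ... across successive reuses of
+* the same barrier; waiting on the wrong parity spins until the next phase flip.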
+*/ +__device__ static inline void arrive_and_wait(barrier& bar, int kPhaseBit) { + void const* const ptr = &bar; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + + +/* ---------- Synchronization functions for async store ---------- */ + +/** + * @brief Commits previous asynchronous TMA stores to a group and performs them. +*/ +__device__ static inline void store_commit_group() { + if (::kittens::laneid() == 0) { + asm volatile("cp.async.bulk.commit_group;"); + } +} +/** + * @brief Waits for previous committed TMA store groups to complete. + * + * @tparam N The maximum number of remaining TMA store groups. Defaults to 0. +*/ +template +__device__ static inline void store_async_wait() { + asm volatile ( + "cp.async.bulk.wait_group %0;" + : + : "n"(N) + : "memory" + ); + __syncwarp(); +} + +} // namespace tma +} // namespace kittens \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/util/util.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/util/util.cuh new file mode 100644 index 0000000..1704c7c --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/util/util.cuh @@ -0,0 +1,34 @@ +/** + * @file + * @brief General memory utilities not specialized for either tiles or vectors. + */ + +#pragma once + +namespace kittens { + +// template magic allows arrays of these objects to be copied, too. +namespace detail { +template struct size_info; +template struct size_info { + static constexpr uint32_t elements = ST::num_elements; + static constexpr uint32_t bytes = ST::num_elements * sizeof(typename ST::dtype); +}; +template struct size_info { + static constexpr uint32_t elements = SV::length; + static constexpr uint32_t bytes = SV::length * sizeof(typename SV::dtype); +}; +template struct size_info { + static constexpr uint32_t elements = dim*size_info::elements; + static constexpr uint32_t bytes = dim*size_info::bytes; +}; +} +template constexpr uint32_t size_elements = detail::size_info::elements; +template constexpr uint32_t size_bytes = detail::size_info::bytes; + +} + +#ifdef KITTENS_HOPPER +#include "tma.cuh" +#include "dsmem.cuh" +#endif \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/vec/dsmem.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/vec/dsmem.cuh new file mode 100644 index 0000000..d023222 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/vec/dsmem.cuh @@ -0,0 +1,31 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/shared/shared.cuh" +#include "../util/util.cuh" + +namespace kittens { +namespace dsmem { + +/** + * @brief Distributes data from a source shared vector to a destination shared vector across different thread blocks. + * + * This function wraps the distribute function by automatically calculating the number of bytes to be transferred + * based on the shared vector type and optional dimensions provided. It facilitates the distribution of data across + * different clusters or thread blocks in a device. + * + * @tparam SV The shared vector type. + * @tparam dims Variadic template parameter representing the dimensions of the array of shared vectors to be distributed. + * @param[in,out] dst_ Reference to the destination shared vector. + * @param[in,out] src_ Reference to the source shared vector. 
+ * @param[in] cluster_size The size of the cluster or the number of thread blocks involved in the distribution. + * @param[in] dst_idx The index of the destination thread block within the cluster. + * @param[in,out] bar Reference to a barrier used for synchronization across thread blocks. + */ +template +__device__ static inline void distribute(SV &dst_, SV &src_, int cluster_size, int dst_idx, barrier& bar) { + distribute(dst_, src_, cluster_size, dst_idx, kittens::size_bytes, bar); // wrap with auto calculated bytes +} + +} +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_register.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_register.cuh new file mode 100644 index 0000000..0f2a32c --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_register.cuh @@ -0,0 +1,120 @@ +/** + * @file + * @brief Functions for transferring data directly between global memory and registers and back. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/** + * @brief Load data into a register vector from a source array in global memory. + * + * @tparam RV The register vector type. + * @tparam U The data type of the source array. + * @param[out] dst The destination register vector to load data into. + * @param[in] src The source array in global memory to load data from. + */ +template +__device__ inline static void load(RV &dst, const U *src) { + using T2 = RV::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + int laneid = ::kittens::laneid(); + + __syncwarp(); + if constexpr (dst.inner_dim == 2) { + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+3)/4; w++) { + int idx = w*64 + (laneid/4)*8 + 2*(laneid%4); + int o_dim = w*4 + (laneid/4) / 2; + int i_dim = (laneid/4) % 2; + // this should be a maximally coalesced load. + if(idx < dst.outer_dim*16) + dst[o_dim][i_dim] = base_types::convertor::convert(*(U2*)&src[idx]); + } + __syncwarp(); + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = 8*(w%4) + (laneid%4); // repeats every 64 columns + dst[w][0] = packed_shfl_sync(MASK_ALL, dst[w][0], leader); + dst[w][1] = packed_shfl_sync(MASK_ALL, dst[w][1], leader+4); + } + } + else { + // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true + // otherwise there will be some pain :/ + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+1)/2; w++) { + int idx = w*32 + (laneid%4)*8 + (laneid/4); + int o_dim = w*2 + (laneid%4) / 2; + // this should be a maximally coalesced load. + if(idx < dst.outer_dim*16) { + T tmp = base_types::convertor::convert(src[idx]); + if(laneid%2==0) dst[o_dim][0].x = tmp; + else dst[o_dim][0].y = tmp; + } + } + __syncwarp(); + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = (laneid/4)*4 + 2*(w%2); // repeats every 64 columns + dst[w][0].x = __shfl_sync(MASK_ALL, dst[w][0].x, leader); + dst[w][0].y = __shfl_sync(MASK_ALL, dst[w][0].y, leader+1); + } + } +} + +/** + * @brief Store data from a register vector to a destination array in global memory. + * + * @tparam RV The register vector type. + * @tparam U The data type of the destination array. 
+ * @param[out] dst The destination array in global memory to store data into.
+ * @param[in] src The source register vector to store data from.
+ */
+template<ducks::rv::all RV, typename U>
+__device__ inline static void store(U *dst, const RV &src) {
+    using T2 = RV::dtype;
+    using U2 = base_types::packing<U>::packed_type;
+    using T  = base_types::packing<T2>::unpacked_type;
+
+    int laneid = ::kittens::laneid();
+
+    __syncwarp();
+    if constexpr (src.inner_dim == 2) {
+        #pragma unroll
+        for(auto w = 0; w < (src.outer_dim+3)/4; w++) {
+            int idx = w*64 + (laneid/4)*8 + 2*(laneid%4);
+            int o_dim = w*4 + (laneid/4) / 2;
+            int i_dim = (laneid/4) % 2;
+            // this should be a maximally coalesced store. I hope!
+            if(idx < src.outer_dim*16)
+                *(U2*)&dst[idx] = base_types::convertor<U2, T2>::convert(src[o_dim][i_dim]);
+        }
+    }
+    else {
+        // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true
+        // otherwise there will be some pain :/
+        #pragma unroll
+        for(auto w = 0; w < (src.outer_dim+1)/2; w++) {
+            int idx = w*32 + (laneid%4)*8 + (laneid/4);
+            int o_dim = w*2 + (laneid%4) / 2;
+            // this should be a maximally coalesced store.
+            if(idx < src.outer_dim*16) {
+                U tmp;
+                if(laneid%2==0) tmp = base_types::convertor<U, T>::convert(src[o_dim][0].x);
+                else tmp = base_types::convertor<U, T>::convert(src[o_dim][0].y);
+                dst[idx] = tmp;
+            }
+        }
+    }
+}
+
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_shared.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_shared.cuh
new file mode 100644
index 0000000..ef38f86
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/memory/vec/global_to_shared.cuh
@@ -0,0 +1,52 @@
+/**
+ * @file
+ * @brief Functions for transferring data directly between global and shared memory and back.
+ */
+
+#pragma once
+
+#include <cuda/pipeline>
+
+#include "../../../../common/common.cuh"
+#include "../../../../types/types.cuh"
+
+namespace kittens {
+
+/**
+ * @brief Loads data from global memory into a shared memory vector.
+ *
+ * @tparam SV The shared memory vector type.
+ * @param[out] dst The destination shared memory vector.
+ * @param[in] src The source global memory array.
+ */
+template<ducks::sv::all SV>
+__device__ static inline void load(SV &dst, const typename SV::dtype *src) {
+    constexpr int elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
+    constexpr int total_calls = dst.length / elem_per_transfer; // guaranteed to divide
+    __syncwarp();
+    #pragma unroll
+    for(int i = ::kittens::laneid(); i < total_calls; i+=WARP_THREADS) {
+        if(i * elem_per_transfer < dst.length)
+            *(float4*)&dst[i*elem_per_transfer] = *(float4*)&src[i*elem_per_transfer];
+    }
+}
+/**
+ * @brief Stores data from a shared memory vector into global memory.
+ *
+ * @tparam SV The shared memory vector type.
+ * @param[out] dst The destination global memory array.
+ * @param[in] src The source shared memory vector.
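+ *
+ * A minimal round-trip sketch (hedged: `sv_bf_1` is assumed to be a
+ * 16-element bf16 shared-vector alias and `gmem` a device pointer; neither
+ * is defined in this file):
+ * @code
+ * __global__ void roundtrip(kittens::bf16 *gmem) {
+ *     __shared__ kittens::sv_bf_1 vec;
+ *     kittens::load(vec, gmem);  // global -> shared, float4-wide transfers
+ *     kittens::store(gmem, vec); // shared -> global, same access pattern
+ * }
+ * @endcode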
+ */ +template +__device__ static inline void store(typename SV::dtype *dst, const SV &src) { + constexpr int elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype); + constexpr int total_calls = src.length / elem_per_transfer; // guaranteed to divide + __syncwarp(); + #pragma unroll + for(int i = ::kittens::laneid(); i < total_calls; i+=WARP_THREADS) { + if(i * elem_per_transfer < src.length) + *(float4*)&dst[i*elem_per_transfer] = *(float4*)&src[i*elem_per_transfer]; // lmao it's identical + } +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/vec/shared_to_register.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/vec/shared_to_register.cuh new file mode 100644 index 0000000..20eb311 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/vec/shared_to_register.cuh @@ -0,0 +1,128 @@ +/** + * @file + * @brief Functions for transferring data directly between shared memory and registers and back. + */ + +#pragma once + +#include + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/** + * @brief Load data from a shared vector into a register vector. + * + * @tparam RV The register vector type + * @tparam SV The shared vector type + * @param dst[out] The destination register vector. + * @param src[in] The source shared vector. + */ +template +__device__ inline static void load(RV &dst, const SV &src) { + using T2 = RV::dtype; + using U = SV::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + static_assert(src.tiles == dst.outer_dim); + + int laneid = ::kittens::laneid(); + + __syncwarp(); + if constexpr (dst.inner_dim == 2) { + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+3)/4; w++) { + int idx = w*64 + (laneid/4)*8 + 2*(laneid%4); + int o_dim = w*4 + (laneid/4) / 2; + int i_dim = (laneid/4) % 2; + // this should be a maximally coalesced load. + if(idx < dst.outer_dim*16) + dst[o_dim][i_dim] = base_types::convertor::convert(*(U2*)&src[idx]); + } + __syncwarp(); + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = 8*(w%4) + (laneid%4); // repeats every 64 columns + dst[w][0] = packed_shfl_sync(MASK_ALL, dst[w][0], leader); + dst[w][1] = packed_shfl_sync(MASK_ALL, dst[w][1], leader+4); + } + } + else { + // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true + // otherwise there will be some pain :/ + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+1)/2; w++) { + int idx = w*32 + (laneid%4)*8 + (laneid/4); + int o_dim = w*2 + (laneid%4) / 2; + // this should be a maximally coalesced load. + if(idx < dst.outer_dim*16) { + T tmp = base_types::convertor::convert(src[idx]); + if(laneid%2==0) dst[o_dim][0].x = tmp; + else dst[o_dim][0].y = tmp; + } + } + __syncwarp(); + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = (laneid/4)*4 + 2*(w%2); // repeats every 64 columns + dst[w][0].x = __shfl_sync(MASK_ALL, dst[w][0].x, leader); + dst[w][0].y = __shfl_sync(MASK_ALL, dst[w][0].y, leader+1); + } + } +} + +/** + * @brief Store data into a shared vector from a register vector. + * + * @tparam RV The register vector type + * @tparam SV The shared vector type + * @param dst[out] The destination shared vector. 
+ * @param src[in] The source register vector. + */ +template +__device__ inline static void store(SV &dst, const RV &src) { + using T2 = RV::dtype; + using U = SV::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + static_assert(dst.tiles == src.outer_dim); + + int laneid = ::kittens::laneid(); + + __syncwarp(); + if constexpr (src.inner_dim == 2) { + #pragma unroll + for(auto w = 0; w < (src.outer_dim+3)/4; w++) { + int idx = w*64 + (laneid/4)*8 + 2*(laneid%4); + int o_dim = w*4 + (laneid/4) / 2; + int i_dim = (laneid/4) % 2; + // this should be a maximally coalesced store. I hope! + if(idx < src.outer_dim*16) + *(U2*)&dst[idx] = base_types::convertor::convert(src[o_dim][i_dim]); + } + } + else { + // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true + // otherwise there will be some pain :/ + #pragma unroll + for(auto w = 0; w < (src.outer_dim+1)/2; w++) { + int idx = w*32 + (laneid%4)*8 + (laneid/4); + int o_dim = w*2 + (laneid%4) / 2; + // this should be a maximally coalesced load. + if(idx < src.outer_dim*16) { + U tmp; + if(laneid%2==0) tmp = base_types::convertor::convert(src[o_dim][0].x); + else tmp = base_types::convertor::convert(src[o_dim][0].y); + dst[idx] = tmp; + } + } + } +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/vec/tma.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/vec/tma.cuh new file mode 100644 index 0000000..343565d --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/vec/tma.cuh @@ -0,0 +1,270 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" +#include "../util/util.cuh" + +#include +#include + +namespace kittens { +namespace tma { + +/* ---------- Create tensor map descriptor (HOST) ---------- */ + +/** +* @brief Creates a tensor map for the given source vector. +* +* This function creates a tensor map (CUtensorMap) for the specified source shared vector type. The tensor map +* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor +* map based on the provided source tensor pointer and the layout specified by the SV template parameter. +* +* @tparam SV The source tensor type, which must be TMA-compatible. +* @tparam num_vectors The number of vectors present in global memory. +* @param tma_map Pointer to the CUtensorMap object to be initialized. +* @param src Pointer to the source tensor data in global memory. 
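+*
+* Host-side usage sketch (hedged: `sv_bf_1` is an assumed 16-element bf16
+* shared-vector alias, and the count of 8 vectors is arbitrary):
+* @code
+* bf16 *d_src;
+* cudaMalloc(&d_src, 8 * 16 * sizeof(bf16));
+* CUtensorMap tma_map;
+* create_tensor_map<sv_bf_1>(&tma_map, d_src, 8); // describes 8 contiguous vectors
+* @endcode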
+*/ +template +__host__ static inline void create_tensor_map(CUtensorMap *tma_map, const bf16 *src, int num_vectors) { + + constexpr uint32_t tma_dim = 1; + void *global_addr = (void*)(src); + + constexpr CUtensorMapDataType tma_format = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + constexpr CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + constexpr CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + constexpr CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + + uint64_t gmem_shape [1] = {SV::length * num_vectors}; + uint64_t gmem_stride[1] = {1}; + uint32_t smem_shape [1] = {SV::length}; + uint32_t smem_stride[1] = {1}; + + // ensure that the global address is always 16-byte aligned + assert((reinterpret_cast(global_addr) & 0b1111) == 0); + + assert(smem_shape[0] <= 256); // smem_shape[0] elements must be <= 256 + + const uint64_t *gmem_shape_ptr = &gmem_shape[0]; + const uint64_t *gmem_stride_ptr = &gmem_stride[0]; + const uint32_t *smem_shape_ptr = &smem_shape[0]; + const uint32_t *smem_stride_ptr = &smem_stride[0]; + + CUresult result = cuTensorMapEncodeTiled( + tma_map, + tma_format, + tma_dim, + global_addr, + gmem_shape_ptr, + gmem_stride_ptr, + smem_shape_ptr, + smem_stride_ptr, + tma_interleave, + swizzle, + tma_l2Promotion, + tma_oobFill + ); + + const char *error_string; + CUresult res = cuGetErrorString(result, &error_string); + if (result != CUDA_SUCCESS) { + std::cerr << "Error: " << error_string << std::endl; + } +}; + +/** +* @brief Allocates on the GPU and initializes a tensor map for the given source tensor. +* +* This function creates a tensor map (CUtensorMap) for the specified source shared vector type. The tensor map +* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor +* map based on the provided source tensor pointer and the layout specified by the SV template parameter. +* +* @tparam SV The source tensor type, which must be TMA-compatible. +* @tparam num_vectors The number of vectors present in global memory. +* @param src Pointer to the source tensor data in global memory. +* @returns Pointer to the CUtensorMap object to be initialized. +*/ +template +__host__ static inline CUtensorMap* allocate_and_create_tensor_map(const bf16 *src, int num_vectors) { + CUtensorMap *tma_map_d; + cudaMalloc(&tma_map_d, sizeof(CUtensorMap)); + CUtensorMap tma_map_host; // put it on the stack, why not. + create_tensor_map(&tma_map_host, src, num_vectors); + cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + return tma_map_d; +} + +/* ---------- Prefetch Tensor Map ---------- */ + +/** + * @brief Prefetches data from global memory into a shared memory vector, along with the tensormap. + * + * @tparam SV A shared vector type with a TMA-compatible layout + * @param[out] dst The destination shared memory vector. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in] vec_idx The index of the requested vector. 
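+ *
+ * Usage sketch (hedged: `vec` and `tma_desc` are assumed to be a shared
+ * vector and a device-resident tensormap produced by the host helpers above):
+ * @code
+ * // warm the L2 with vector 3 before the matching load_async is issued
+ * kittens::tma::prefetch(vec, tma_desc, 3);
+ * @endcode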
+ */
+template<ducks::sv::all SV>
+__device__ static inline void prefetch(SV &dst, void const* const src_tma_map, int vec_idx) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(src_tma_map);
+
+        int32_t crd0 = vec_idx * (dst.length);
+
+        asm volatile (
+            "cp.async.bulk.prefetch.tensor.1d.L2.global.tile"
+            " [%0, {%1}];"
+            :
+            : "l"(tma_ptr), "r"(crd0)
+            : "memory"
+        );
+    }
+}
+
+/* ---------- Async load and store data from gmem/smem ---------- */
+
+/**
+ * @brief Asynchronously stores data into global memory from a shared memory vector.
+ *
+ * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
+ *
+ * @tparam SV A shared vector type with a TMA-compatible layout
+ * @param[out] dst_tma_map The destination tensormap address in global memory
+ * @param[in] src The source shared memory vector.
+ * @param[in] vec_idx The index of the vector destination.
+ */
+template<ducks::sv::all SV>
+__device__ static inline void store_async(void *dst_tma_map, const SV &src, int vec_idx) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst_tma_map);
+        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
+
+        int32_t crd0 = vec_idx * (src.length);
+
+        asm volatile (
+            "cp.async.bulk.tensor.1d.global.shared::cta.tile.bulk_group"
+            " [%0, {%2}], [%1];"
+            :
+            : "l"(tma_ptr), "r"(src_ptr), "r"(crd0)
+            : "memory"
+        );
+    }
+}
+
+/**
+* @brief Asynchronously performs an add reduction and stores the result into global memory.
+*
+* This function performs an asynchronous add reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction.
+*
+* @tparam SV A shared vector type with a TMA-compatible layout
+* @param[out] dst_tma_map The destination tensormap address in global memory
+* @param[in] src The source shared memory vector.
+* @param[in] vec_idx The index of the vector destination.
+*/
+template<ducks::sv::all SV>
+__device__ static inline void store_add_async(void *dst_tma_map, const SV &src, int vec_idx) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst_tma_map);
+        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
+
+        int32_t crd0 = vec_idx * (src.length);
+
+        asm volatile (
+            "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.tile.bulk_group"
+            " [%0, {%2}], [%1];"
+            :
+            : "l"(tma_ptr), "r"(src_ptr), "r"(crd0)
+            : "memory"
+        );
+    }
+}
+
+/**
+* @brief Asynchronously performs a min reduction and stores the result into global memory.
+*
+* This function performs an asynchronous min reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction.
+*
+* @tparam SV A shared vector type with a TMA-compatible layout
+* @param[out] dst_tma_map The destination tensormap address in global memory
+* @param[in] src The source shared memory vector.
+* @param[in] vec_idx The index of the vector destination.
+*/
+template<ducks::sv::all SV>
+__device__ static inline void store_min_async(void *dst_tma_map, const SV &src, int vec_idx) {
+    if (::kittens::laneid() == 0) {
+        uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst_tma_map);
+        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
+
+        int32_t crd0 = vec_idx * (src.length);
+
+        asm volatile (
+            "cp.reduce.async.bulk.tensor.1d.global.shared::cta.min.tile.bulk_group"
+            " [%0, {%2}], [%1];"
+            :
+            : "l"(tma_ptr), "r"(src_ptr), "r"(crd0)
+            : "memory"
+        );
+    }
+}
+
+/**
+* @brief Asynchronously performs a max reduction and stores the result into global memory.
+* +* This function performs an asynchronous max reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction. +* +* @tparam SV A shared vector type with a TMA-compatible layout +* @param[out] dst_tma_map The destination tensormap address in global memory +* @param[in] src The source shared memory vector. +* @param[in] vec_idx The index of the vector destination. +*/ +template +__device__ static inline void store_max_async(void *dst_tma_map, const SV &src, int vec_idx) { + if (::kittens::laneid() == 0) { + uint64_t tma_ptr = reinterpret_cast(dst_tma_map); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + + int32_t crd0 = vec_idx * (src.length); + + asm volatile ( + "cp.reduce.async.bulk.tensor.1d.global.shared::cta.max.tile.bulk_group" + " [%0, {%2}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), "r"(crd0) + : "memory" + ); + } +} + +/** + * @brief Asynchronously loads data from global memory into a shared memory vector. + * + * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction. + * + * @tparam SV A shared vector type with a TMA-compatible layout + * @param[out] dst The destination shared memory vector. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in] vec_idx The index of the requested vector. + * @param[in,out] bar The barrier used for synchronization of the asynchronous copy. + */ +template +__device__ static inline void load_async(SV &dst, void const* const src_tma_map, barrier& bar, int vec_idx) { + if (::kittens::laneid() == 0) { + uint64_t tma_ptr = reinterpret_cast(src_tma_map); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst)); + + int32_t crd0 = vec_idx * (dst.length); + + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3}], [%2];" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), "r"(crd0) + : "memory" + ); + } +} + +} // namespace tma +} // namespace kittens \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/memory/vec/vec.cuh b/triteia/csrc/flash_kittens/ops/warp/memory/vec/vec.cuh new file mode 100644 index 0000000..fcfdc79 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/memory/vec/vec.cuh @@ -0,0 +1,15 @@ +/** + * @file + * @brief An aggregate header of warp memory operations on vectors, where a single warp loads or stores data on its own. + */ + +#pragma once + +#include "shared_to_register.cuh" +#include "global_to_register.cuh" +#include "global_to_shared.cuh" + +#ifdef KITTENS_HOPPER +#include "tma.cuh" +#include "dsmem.cuh" +#endif \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/register.cuh b/triteia/csrc/flash_kittens/ops/warp/register/register.cuh new file mode 100644 index 0000000..a802538 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/register.cuh @@ -0,0 +1,9 @@ +/** + * @file + * @brief An aggregate header for warp operations on data stored in registers. 
+ */ + +#pragma once + +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/tile/conversions.cuh b/triteia/csrc/flash_kittens/ops/warp/register/tile/conversions.cuh new file mode 100644 index 0000000..098b6ce --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/tile/conversions.cuh @@ -0,0 +1,323 @@ +/** + * @file + * @brief Conversions between data layouts and types for register tiles. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/* ---------- LAYOUT SWAPS ---------- */ + +/** + * @brief Perform a matrix transpose on a block of 8 bf16_2 elements using inline assembly. + * + * This low-level operation is utilized by higher-level layout swap functions to transpose + * the layout of bf16_2 elements within a register tile. The function leverages inline PTX + * assembly to efficiently swap the layout of the given block. + * + * @param[out] dst A reference to the destination bf16_2 element where the transposed result is stored. + * @param[in] src A reference to the source bf16_2 element to be transposed. + */ +__device__ inline void swap_layout_8(bf16_2 &dst, const bf16_2 &src) { + asm volatile ( + "movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" + : "+r"(*(uint32_t*)(&dst)) + : "r"(*(uint32_t*)(&src)) + ); +} +/** + * @brief Swaps the layout of a register base tile. + * + * This function swaps the layout of a register base tile by performing a series of layout swaps + * on its constituent bf16_2 elements. It is used to change the data layout within a register tile. + * + * @tparam T2 The data type of the register tile elements. + * @tparam layout The current layout of the register tile. + * @param dst[out] Reference to the destination register base tile where the result will be stored. + * @param src[in] Reference to the source register base tile to be swapped. + */ +template +__device__ inline void swap_layout(rt_base::type> &dst, const rt_base &src) { + swap_layout_8(dst.data[0], src.data[0]); + // technically this swap can be eliminated if we simply reinterpret the layout of the registers + // everywhere else in the code, but that feels... very likely to cause bugs and not worth it. + T2 data1_cache = src.data[1]; // important for swap! + swap_layout_8(dst.data[1], src.data[2]); + swap_layout_8(dst.data[2], data1_cache); + swap_layout_8(dst.data[3], src.data[3]); +} +/** + * @brief Swaps the layout of a register tile. + * + * This function swaps the layout of a register tile by iterating over its height and width + * and performing layout swaps on each of its base elements. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the register tile. + * @tparam _width The width of the register tile. + * @tparam layout The current layout of the register tile. + * @param dst[out] Reference to the destination register tile where the result will be stored. + * @param src[in] Reference to the source register tile to be swapped. + */ +template +__device__ static inline void swap_layout(rt::type> &dst, const rt &src) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + swap_layout(dst.tiles[i][j], src.tiles[i][j]); + } + } +} + +/** + * @brief Swaps the layout of a register base tile in place. 
+ * + * This function swaps the layout of a register base tile in place by casting it to the + * transposed layout type and then performing the layout swap. + * + * @tparam T2 The data type of the register tile elements. + * @tparam layout The current layout of the register tile. + * @param src[in] Reference to the register base tile to be swapped in place. + * @return A reference to the swapped register base tile. + */ +template +__device__ inline rt_base::type>& swap_layout_inplace(const rt_base &src) { + rt_base::type> &dst = *(rt_base::type>*)(&src); + swap_layout(dst, src); + return dst; +} +/** + * @brief Swaps the layout of a register tile in place. + * + * This function swaps the layout of a register tile in place by iterating over its height and width + * and performing in-place layout swaps on each of its base elements. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the register tile. + * @tparam _width The width of the register tile. + * @tparam layout The current layout of the register tile. + * @param tile[in,out] Reference to the register tile to be swapped in place. + * @return A reference to the swapped register tile. + */ +template +__device__ static inline rt::type>& swap_layout_inplace(rt &tile) { + #pragma unroll + for(int i = 0; i < tile.height; i++) { + #pragma unroll + for(int j = 0; j < tile.width; j++) { + swap_layout_inplace(tile.tiles[i][j]); + } + } + return *(rt::type>*)(&tile); +} + +/* ---------- TRANSPOSE ---------- */ + +/** + * @brief Transposes a register base tile. + * + * @tparam T2 The data type of the register tile elements. + * @tparam layout The current layout of the register tile. + * @param dst[out] Reference to the register tile in which to store the transposed src. + * @param src[in] Reference to the register base tile to be transposed. + */ +template +__device__ inline void transpose(rt_base &dst, const rt_base &src) { + swap_layout_8(dst.data[0], src.data[0]); + // technically this swap can be eliminated if we simply reinterpret the layout of the registers + // everywhere else in the code, but that feels... very likely to cause bugs and not worth it. + T2 data1_cache = src.data[1]; // important for swap! + swap_layout_8(dst.data[1], src.data[2]); + swap_layout_8(dst.data[2], data1_cache); + swap_layout_8(dst.data[3], src.data[3]); +} +/** + * @brief Transposes a register tile. + * + * This function is marked "sep", which means that the registers underlying dst MUST be separate + * from the registers underlying src. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the src register tile, and the width of the dst tile. + * @tparam _width The width of the src register tile, and the height of the dst tile. + * @tparam layout The layout of the register tile. + * @param dst[out] Reference to the register tile in which to store the transposed src. + * @param src[in] Reference to the register tile to be transposed. + */ +template +__device__ static inline void transpose_sep(rt &dst, const rt &src) { + #pragma unroll + for(int i = 0; i < _height; i++) { + #pragma unroll + for(int j = 0; j < _width; j++) { + transpose(dst.tiles[j][i], src.tiles[i][j]); + } + } +} + +/** + * @brief Transposes a register base tile in-place. + * + * @tparam T2 The data type of the register base tile elements. + * @tparam layout The current layout of the register base tile. + * @param src[in] Reference to the register tile to be transposed. 
+ * @return A reference to the transposed register base tile. + */ +template +__device__ inline rt_base& transpose_inplace(rt_base &src) { + transpose(src, src); + return src; +} +/** + * @brief Transposes a square register tile in-place. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height (in units of 16) of the src register tile, and the width of the dst tile. (Must be the same as _width.) + * @tparam _width The width (in units of 16) of the src register tile, and the height of the dst tile. (Must be the same as _height.) + * @tparam layout The current layout of the register tile. + * @param src[in] Reference to the register tile to be transposed. + * @return A reference to the transposed register tile. + */ +template +__device__ static inline rt& transpose_inplace(rt &tile) { + static_assert(_width == _height, "in-place register tile transpose is only allowed for square tiles."); + #pragma unroll + for(int i = 0; i < _height; i++) { + #pragma unroll + for(int j = 0; j < i; j++) { + rt_base tmp; + copy(tmp, tile.tiles[i][j]); + transpose(tile.tiles[i][j], tile.tiles[j][i]); + transpose(tile.tiles[j][i], tmp); + } + transpose_inplace(tile.tiles[i][i]); + } + return tile; +} + +/* ---------- TYPE SWAPS ---------- */ + +/** + * @brief Copies a register base tile, converting the underlying type if necessary. + * + * @tparam T2 The data type of the destination register elements. + * @tparam U2 The data type of the source register elements. + * @tparam layout The current layout of the register base tile. + * @param[out] dst A reference to the destination register base tile. + * @param[in] src A reference to the source register base tile. + */ +template +__device__ static inline void copy(rt_base &dst, const rt_base &src) { + #pragma unroll + for(int k = 0; k < dst.packed_per_thread; k++) { + dst.data[k] = base_types::convertor::convert(src.data[k]); + } +} +/** + * @brief Copies a register tile, converting the underlying type if necessary. + * + * @tparam T2 The data type of the destination register elements. + * @tparam U2 The data type of the source register elements. + * @tparam _height The height (in units of 16) of the register tiles. + * @tparam _width The width (in units of 16) of the register tiles. + * @tparam layout The current layout of the register tile. + * @param[out] dst A reference to the destination register tile. + * @param[in] src A reference to the source register tile. + */ +template +__device__ static inline void copy(rt &dst, const rt &src) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + copy(dst.tiles[i][j], src.tiles[i][j]); + } + } +} + +/* ---------- CAUSAL ---------- */ + +/** + * @brief Makes a square register tile causal by zeroing elements above the main diagonal. + * + * This function modifies a square register tile in-place to make it causal. All elements + * above the main diagonal are set to zero, while elements on or below the main diagonal + * are left unchanged. + * + * @tparam T The data type of the register tile elements. + * @tparam _size The size (height and width) of the square register tile. + * @tparam layout The current layout of the register tile. + * @param tile[in,out] Reference to the register tile to be made causal. 
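+ *
+ * Attention-style usage sketch (hedged: `rt_fl_1x1` is the conventional
+ * 16x16 fp32 register-tile alias, and the fill value assumes the neg_infty
+ * constant from base_types; both live outside this file):
+ * @code
+ * kittens::rt_fl_1x1<> att; // attention scores, row layout
+ * kittens::make_causal(att, att, kittens::base_types::constants<float>::neg_infty());
+ * @endcode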
+ */ +template +__device__ static inline void make_causal(RT &dst, const RT &src, const typename base_types::packing::unpacked_type &val=0) { + const typename RT::dtype packed_val = base_types::packing::pack(val); + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if(j < i) { // below the diagonal, copy + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = src.tiles[i][j].data[k]; + } + } + else if(j > i) { // above the diagonal, zero + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = packed_val; + } + } + else { // on the diagonal, interesting! + constexpr uint32_t MASK_X = 0xFF773311, MASK_Y = 0xF7733110; // magic numbers for on-diagonal core matrices + dst.tiles[i][j].data[1] = src.tiles[i][j].data[1]; // below diagonal, copy + dst.tiles[i][j].data[2] = packed_val; // above diagonal, zero + if((MASK_X >> laneid()) & 1) { + dst.tiles[i][j].data[0].x = src.tiles[i][j].data[0].x; + dst.tiles[i][j].data[3].x = src.tiles[i][j].data[3].x; + } + else { + dst.tiles[i][j].data[0].x = val; + dst.tiles[i][j].data[3].x = val; + } + if((MASK_Y >> laneid()) & 1) { + dst.tiles[i][j].data[0].y = src.tiles[i][j].data[0].y; + dst.tiles[i][j].data[3].y = src.tiles[i][j].data[3].y; + } + else { + dst.tiles[i][j].data[0].y = val; + dst.tiles[i][j].data[3].y = val; + } + } + } + } +} + + +/* ---------- SUBTILE ---------- */ + +/** +* @brief Returns a reference to a subtile of the given tile. +* +* @tparam subtile_height The height of the subtile. +* @tparam RT The type of the input tile, which must satisfy the ducks::rt::all concept. +* @param src The input tile. +* @param idx The index of the subtile. +* @return A reference to the subtile. +* +* @note The subtile height must evenly divide the tile height. +*/ +template +__device__ inline rt &subtile_inplace(RT & src, int idx) { + static_assert(RT::height % subtile_height == 0, "subtile height should evenly divide tile height."); + return reinterpret_cast&>( + src.tiles[idx*subtile_height] + ); +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/tile/maps.cuh b/triteia/csrc/flash_kittens/ops/warp/register/tile/maps.cuh new file mode 100644 index 0000000..727df1e --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/tile/maps.cuh @@ -0,0 +1,682 @@ +/** + * @file + * @brief Map operations: between tiles, and those which apply vectors to tiles. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/* ---------- Uniform tile maps (independent of layout) ---------- */ + +/** + * @brief Applies a unary operation to each element of a tile. + * + * @tparam op Unary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + */ +template +__device__ static inline void unary_map(T &dst, const T &src) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = op::template op(src.tiles[i][j].data[k]); + } + } + } +} + +/** + * @brief Applies a binary operation to each element of a tile with a scalar parameter. + * + * @tparam op Binary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. 
+ * @param src[in] Source tile to apply the operation on. + * @param param[in] Scalar parameter for the binary operation. + */ +template +__device__ static inline void bin_map(T &dst, const T &src, const typename T::dtype ¶m) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = op::template op(src.tiles[i][j].data[k], param); + } + } + } +} +/** + * @brief Applies a binary operation to each element of a tile with an unpacked scalar parameter. + * + * @tparam op Binary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param param[in] Unpacked scalar parameter for the binary operation. + */ +template +__device__ static inline void bin_map(T &dst, const T &src, const typename base_types::packing::unpacked_type ¶m) { + // The optimizing compiler should eliminate this pack in the 32-bit case but not in the 16-bit case + bin_map(dst, src, base_types::packing::pack(param)); +} +/** + * @brief Applies a binary operation element-wise between two tiles. + * + * @tparam op Binary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the operation. + * @param rhs[in] Right-hand side source tile for the operation. + */ +template +__device__ static inline void bin_map(T &dst, const T &lhs, const T &rhs) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = op::template op(lhs.tiles[i][j].data[k], rhs.tiles[i][j].data[k]); + } + } + } +} + +/* ---------- Row tile maps ----------*/ + +/** + * @brief Applies an operation across the rows of a tile in a row-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. + */ +template +__device__ static inline void row_map(T &dst, const T &src, const V &row_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::col_vec_pack); // compatible layout + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + dtype packed_top_row = base_types::packing::pack(row_values[i][0].x); // first value in eager mode + dtype packed_bottom_row = base_types::packing::pack(row_values[i][0].y); // second value in eager mode + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], packed_top_row); + dst.tiles[i][j].data[k+1] = op::template op(src.tiles[i][j].data[k+1], packed_bottom_row); + } + } + } +} +/** + * @brief Applies an operation across the rows of a tile in a column-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. 
+ * @param src[in] Source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. + */ +template +__device__ static inline void row_map(T &dst, const T &src, const V &row_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::col_vec_pack); // compatible layout + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], row_values[i][0]); + dst.tiles[i][j].data[k+2] = op::template op(src.tiles[i][j].data[k+2], row_values[i][1]); + } + } + } +} + + +// Three-operand row map. Mostly useful for FMA instructions. + +/** + * @brief Applies an operation across the rows of two tiles in a row-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. + */ +template +__device__ static inline void row_map(T &dst, const T &a, const T &b, const V &row_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::col_vec_pack); // compatible layout + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + dtype packed_top_row = base_types::packing::pack(row_values[i][0].x); // first value in eager mode + dtype packed_bottom_row = base_types::packing::pack(row_values[i][0].y); // second value in eager mode + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], packed_top_row); + dst.tiles[i][j].data[k+1] = op::template op(a.tiles[i][j].data[k+1], b.tiles[i][j].data[k+1], packed_bottom_row); + } + } + } +} +/** + * @brief Applies an operation across the rows of two tiles in a column-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. 
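+ *
+ * Generic usage sketch (hedged: `fma_AxBtC` is a hypothetical ternary
+ * a*b+c op name; substitute whichever ternary op base_ops.cuh provides):
+ * @code
+ * // dst = a * b + row_values, with row_values broadcast along each row
+ * kittens::row_map<kittens::base_ops::fma_AxBtC>(dst, a, b, row_values);
+ * @endcode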
+ */ +template +__device__ static inline void row_map(T &dst, const T &a, const T &b, const V &row_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::col_vec_pack); // compatible layout + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], row_values[i][0]); + dst.tiles[i][j].data[k+2] = op::template op(a.tiles[i][j].data[k+2], b.tiles[i][j].data[k+2], row_values[i][1]); + } + } + } +} + +/* ---------- Col major tile maps ----------*/ + +/** + * @brief Applies an operation across the columns of a tile in a row-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &src, const V &col_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::row_vec_pack); // compatible layout + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], col_values[j][0]); + dst.tiles[i][j].data[k+2] = op::template op(src.tiles[i][j].data[k+2], col_values[j][1]); + } + } + } +} +/** + * @brief Applies an operation across the columns of a tile in a column-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &src, const V &col_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::row_vec_pack); // compatible layout + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int j = 0; j < dst.width; j++) { + dtype packed_left_col = base_types::packing::pack(col_values[j][0].x); // first value in eager mode + dtype packed_right_col = base_types::packing::pack(col_values[j][0].y); // second value in eager mode + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], packed_left_col); + dst.tiles[i][j].data[k+1] = op::template op(src.tiles[i][j].data[k+1], packed_right_col); + } + } + } +} + +// Three-operand col map +/** + * @brief Applies an operation across the columns of two tiles in a row-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Row vector type. 
+ * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &a, const T &b, const V &col_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::row_vec_pack); // compatible layout + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], col_values[j][0]); + dst.tiles[i][j].data[k+2] = op::template op(a.tiles[i][j].data[k+2], b.tiles[i][j].data[k+2], col_values[j][1]); + } + } + } +} +/** + * @brief Applies an operation across the columns of two tiles in a column-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &a, const T &b, const V &col_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::row_vec_pack); // compatible layout + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + #pragma unroll + for(int j = 0; j < dst.width; j++) { + dtype packed_left_col = base_types::packing::pack(col_values[j][0].x); // first value in eager mode + dtype packed_right_col = base_types::packing::pack(col_values[j][0].y); // second value in eager mode + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], packed_left_col); + dst.tiles[i][j].data[k+1] = op::template op(a.tiles[i][j].data[k+1], b.tiles[i][j].data[k+1], packed_right_col); + } + } + } +} + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// All of the annoying qualifiers *should* be automatically inferred during compile-time. +// So, syntax should just be kittens::add_row(tile, colvec); + +/** + * @brief Sets all elements of a tile to zero. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void zero(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of a tile to one. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void one(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of a tile to positive infinity. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void pos_infty(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of a tile to negative infinity. + * + * @tparam T Tile type. 
+ * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void neg_infty(T &dst) { + unary_map(dst, dst); +} + +/** + * @brief Applies the exponential function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the exponential function on. + */ +template +__device__ static inline void exp(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the natural logarithm function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the natural logarithm function on. + */ +template +__device__ static inline void log(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the absolute value function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the absolute value function on. + */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the rectified linear unit (ReLU) function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the ReLU function on. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Copies the elements from one tile to another. + * + * @tparam T Destination tile type. + * @tparam U Source tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to copy from. + */ +template +__device__ static inline void copy(T &dst, const U &src) { + bin_map(dst, src); +} + +/** + * @brief Applies the max operation element-wise between two tiles or a tile and a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the operation. + * @param rhs[in] Right-hand side source tile or scalar for the operation. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Applies the min operation element-wise between two tiles or a tile and a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the operation. + * @param rhs[in] Right-hand side source tile or scalar for the operation. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Adds two tiles element-wise or adds a scalar to each element of a tile. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the addition. + * @param rhs[in] Right-hand side source tile or scalar for the addition. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Subtracts two tiles element-wise or subtracts a scalar from each element of a tile. 
+ * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the subtraction. + * @param rhs[in] Right-hand side source tile or scalar for the subtraction. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Multiplies two tiles element-wise or multiplies each element of a tile by a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the multiplication. + * @param rhs[in] Right-hand side source tile or scalar for the multiplication. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Divides two tiles element-wise or divides each element of a tile by a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the division. + * @param rhs[in] Right-hand side source tile or scalar for the division. + */ +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +/** + * @brief Adds row values to each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the addition on. + * @param row_values[in] Column vector containing values to add to each row. + */ +template +__device__ static inline void add_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Subtracts row values from each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the subtraction on. + * @param row_values[in] Column vector containing values to subtract from each row. + */ +template +__device__ static inline void sub_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Multiplies each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the multiplication on. + * @param row_values[in] Column vector containing values to multiply each row by. + */ +template +__device__ static inline void mul_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Divides each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the division on. + * @param row_values[in] Column vector containing values to divide each row by. + */ +template +__device__ static inline void div_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Broadcast a vector into into a tile's rows. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. 
+ * @param row_values[in] Column vector containing values to broadcast into rows. + */ +template +__device__ static inline void broadcast_row(T &dst, const V &row_values) { + row_map(dst, dst, row_values); +} + + +// col maps +/** + * @brief Adds column values to each column of a tile. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the addition on. + * @param col_values[in] Row vector containing values to add to each column. + */ +template +__device__ static inline void add_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} +/** + * @brief Subtracts column values from each column of a tile. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the subtraction on. + * @param col_values[in] Row vector containing values to subtract from each column. + */ +template +__device__ static inline void sub_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} +/** + * @brief Multiplies each column of a tile by column values. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the multiplication on. + * @param col_values[in] Row vector containing values to multiply each column by. + */ +template +__device__ static inline void mul_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} +/** + * @brief Divides each column of a tile by column values. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the division on. + * @param col_values[in] Row vector containing values to divide each column by. + */ +template +__device__ static inline void div_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} +/** + * @brief Broadcast a vector into into a tile's columns. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param row_values[in] Row vector containing values to broadcast into cols. + */ +template +__device__ static inline void broadcast_col(T &dst, const V &col_values) { + col_map(dst, dst, col_values); +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/tile/mma.cuh b/triteia/csrc/flash_kittens/ops/warp/register/tile/mma.cuh new file mode 100644 index 0000000..eff3fd7 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/tile/mma.cuh @@ -0,0 +1,448 @@ +/** + * @file + * @brief Matrix multiply-accumulate operations for tiles stored in registers. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/** + * @brief Perform the HMMA.16816 operation. + * + * This function performs the half-precision matrix multiply-accumulate operation + * using the `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32` instruction. + * + * @param[out] d0 The first half of the output float2 accumulator. + * @param[out] d1 The second half of the output float2 accumulator. + * @param[in] a0 The first half of the first input bf16_2 matrix. + * @param[in] a1 The second half of the first input bf16_2 matrix. 
+ * @param[in] a2 The first half of the second input bf16_2 matrix. + * @param[in] a3 The second half of the second input bf16_2 matrix. + * @param[in] b0 The first half of the bf16_2 matrix B. + * @param[in] b1 The second half of the bf16_2 matrix B. + * @param[in] c0 The first half of the float2 accumulator matrix C. + * @param[in] c1 The second half of the float2 accumulator matrix C. + */ +__device__ static inline void hmma16816( float2 &d0, float2 &d1, + const bf16_2 &a0, const bf16_2 &a1, const bf16_2 &a2, const bf16_2 &a3, + const bf16_2 &b0, const bf16_2 &b1, + const float2 &c0, const float2 &c1 ) { + asm volatile( + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "{%8, %9}, " \ + "{%10, %11, %12, %13};" + + // D matrix + : "+f"(d0.x), "+f"(d0.y), + "+f"(d1.x), "+f"(d1.y) + + // A matrix + : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)), + "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)), + + // B matrix + "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)), + + // C matrix + "f"(c0.x), "f"(c0.y), + "f"(c1.x), "f"(c1.y) + ); +} +/** + * @brief Perform the HMMA.16816 operation. + * + * This function performs the half-precision matrix multiply-accumulate operation + * using the `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16` instruction. + * + * @param[out] d0 The first half of the output half_2 accumulator. + * @param[out] d1 The second half of the output half_2 accumulator. + * @param[in] a0 The first half of the first input half_2 matrix. + * @param[in] a1 The second half of the first input half_2 matrix. + * @param[in] a2 The first half of the second input half_2 matrix. + * @param[in] a3 The second half of the second input half_2 matrix. + * @param[in] b0 The first half of the half_2 matrix B. + * @param[in] b1 The second half of the half_2 matrix B. + * @param[in] c0 The first half of the half_2 accumulator matrix C. + * @param[in] c1 The second half of the half_2 accumulator matrix C. + */ +__device__ static inline void hmma16816( half_2 &d0, half_2 &d1, + const half_2 &a0, const half_2 &a1, const half_2 &a2, const half_2 &a3, + const half_2 &b0, const half_2 &b1, + const half_2 &c0, const half_2 &c1 ) { + asm volatile( + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ + "{%0, %1}, " \ + "{%2, %3, %4, %5}, " \ + "{%6, %7}, " \ + "{%8, %9};" + + // D matrix + : "=r"(*(uint32_t*)(&d0)), "=r"(*(uint32_t*)(&d1)) + + // A matrix + : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)), + "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)), + + // B matrix + "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)), + + // C matrix + "r"(*(uint32_t*)(&c0)), "r"(*(uint32_t*)(&c1)) + ); +} +/** + * @brief Base matrix multiply-accumulate operation for row layout. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. 
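+ *
+ * Downstream usage sketch via the public mma_AB wrapper (hedged: the
+ * rt_bf_1x1 / rt_fl_1x1 aliases and the col-layout parameter follow the
+ * usual register-tile conventions and are assumed, not defined here):
+ * @code
+ * kittens::rt_bf_1x1<> a;                                 // 16x16 bf16, row layout
+ * kittens::rt_bf_1x1<kittens::ducks::rt_layout::col> b;   // 16x16 bf16, col layout
+ * kittens::rt_fl_1x1<> c, d;                              // fp32 accumulators
+ * kittens::zero(c);
+ * kittens::mma_AB(d, a, b, c);                            // d = a @ b + c
+ * @endcode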
+ */ +__device__ static inline void mma_AB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +/** + * @brief Base matrix multiply-accumulate operation for row layout. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +/** + * @brief Base dot product operation for row layout. + * + * This function performs the base dot product operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in row-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_ABt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in row-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], // for some reason this one seems to need to be backwards + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], // for some reason this one seems to need to be backwards + c.data[2], c.data[3] + ); +} +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AtB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. 
+ * @param[in] c The input rt_base accumulator matrix.
+ */
+__device__ static inline void mma_AtBt_base(rt_base &d,
+                                      const rt_base &a,
+                                      const rt_base &b, // in col-major mode
+                                      const rt_base &c) {
+    hmma16816(
+        d.data[0], d.data[1],
+        a.data[0], a.data[1], a.data[2], a.data[3],
+        b.data[0], b.data[2],
+        c.data[0], c.data[1]
+    );
+    hmma16816(
+        d.data[2], d.data[3],
+        a.data[0], a.data[1], a.data[2], a.data[3],
+        b.data[1], b.data[3],
+        c.data[2], c.data[3]
+    );
+}
+
+/**
+ * @brief Matrix multiply-accumulate operation.
+ *
+ * This function performs the matrix multiply-accumulate operation
+ * using the `hmma16816` function.
+ *
+ * @tparam N The number of row tiles.
+ * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
+ * @tparam M The number of column tiles for the B matrix.
+ * @param[out] d The output rt_hf accumulator.
+ * @param[in] a The first input rt_hf matrix.
+ * @param[in] b The second input rt_hf matrix in column-major mode.
+ * @param[in] c The input rt_hf accumulator matrix.
+ */
+template
+__device__ static inline void mma_AB(rt_hf &d,
+                              const rt_hf &a,
+                              const rt_hf &b,
+                              const rt_hf &c) {
+    #pragma unroll
+    for(int n = 0; n < N; n++) {
+        #pragma unroll
+        for(int m = 0; m < M; m++) {
+            mma_AB_base(
+                d.tiles[n][m],
+                a.tiles[n][0],
+                b.tiles[0][m],
+                c.tiles[n][m]
+            );
+            #pragma unroll
+            for(int k = 1; k < K; k++) {
+                mma_AB_base(
+                    d.tiles[n][m],
+                    a.tiles[n][k],
+                    b.tiles[k][m],
+                    d.tiles[n][m]
+                );
+            }
+        }
+    }
+}
+/**
+ * @brief Matrix multiply-accumulate operation.
+ *
+ * This function performs the matrix multiply-accumulate operation
+ * using the `hmma16816` function.
+ *
+ * @tparam N The number of row tiles.
+ * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
+ * @tparam M The number of column tiles for the B matrix.
+ * @param[out] d The output rt_fl accumulator.
+ * @param[in] a The first input rt_bf matrix.
+ * @param[in] b The second input rt_bf matrix in column-major mode.
+ * @param[in] c The input rt_fl accumulator matrix.
+ */
+template
+__device__ static inline void mma_AB(rt_fl &d,
+                              const rt_bf &a,
+                              const rt_bf &b,
+                              const rt_fl &c) {
+    #pragma unroll
+    for(int n = 0; n < N; n++) {
+        #pragma unroll
+        for(int m = 0; m < M; m++) {
+            mma_AB_base(
+                d.tiles[n][m],
+                a.tiles[n][0],
+                b.tiles[0][m],
+                c.tiles[n][m]
+            );
+            #pragma unroll
+            for(int k = 1; k < K; k++) {
+                mma_AB_base(
+                    d.tiles[n][m],
+                    a.tiles[n][k],
+                    b.tiles[k][m],
+                    d.tiles[n][m]
+                );
+            }
+        }
+    }
+}
+/**
+ * @brief Dot product operation for row layout.
+ *
+ * This function performs the dot product operation
+ * using the `hmma16816` function.
+ *
+ * @tparam N The number of row tiles.
+ * @tparam K The number of column tiles for the A and B matrices.
+ * @tparam M The number of row tiles for the B matrix.
+ * @param[out] d The output rt_fl accumulator.
+ * @param[in] a The first input rt_bf matrix.
+ * @param[in] b The second input rt_bf matrix in row-major mode.
+ * @param[in] c The input rt_fl accumulator matrix.
+ */
+template
+__device__ static inline void mma_ABt(rt_fl &d,
+                               const rt_bf &a,
+                               const rt_bf &b, // notice row and (M, K) instead of col and (K, M)
+                               const rt_fl &c) {
+    #pragma unroll
+    for(int n = 0; n < N; n++) {
+        #pragma unroll
+        for(int m = 0; m < M; m++) {
+            mma_ABt_base(
+                d.tiles[n][m],
+                a.tiles[n][0],
+                b.tiles[m][0],
+                c.tiles[n][m]
+            );
+            #pragma unroll
+            for(int k = 1; k < K; k++) {
+                mma_ABt_base(
+                    d.tiles[n][m],
+                    a.tiles[n][k],
+                    b.tiles[m][k],
+                    d.tiles[n][m]
+                );
+            }
+        }
+    }
+}
+/**
+ * @brief Matrix multiply-accumulate operation with transposed A.
+ *
+ * This function performs the matrix multiply-accumulate operation
+ * using the `hmma16816` instruction.
+ *
+ * @tparam N The number of row tiles.
+ * @tparam K The number of row tiles for the A matrix and the B matrix.
+ * @tparam M The number of column tiles for the B matrix.
+ * @param[out] d The output rt_fl accumulator.
+ * @param[in] a The first input rt_bf matrix.
+ * @param[in] b The second input rt_bf matrix in column-major mode.
+ * @param[in] c The input rt_fl accumulator matrix.
+ */
+template
+__device__ static inline void mma_AtB(rt_fl &d,
+                               const rt_bf &a,
+                               const rt_bf &b,
+                               const rt_fl &c) {
+    #pragma unroll
+    for(int n = 0; n < N; n++) {
+        #pragma unroll
+        for(int m = 0; m < M; m++) {
+            mma_AtB_base(
+                d.tiles[n][m],
+                a.tiles[0][n],
+                b.tiles[0][m],
+                c.tiles[n][m]
+            );
+            #pragma unroll
+            for(int k = 1; k < K; k++) {
+                mma_AtB_base(
+                    d.tiles[n][m],
+                    a.tiles[k][n],
+                    b.tiles[k][m],
+                    d.tiles[n][m]
+                );
+            }
+        }
+    }
+}
+/**
+ * @brief Matrix multiply-accumulate operation with transposed A and B.
+ *
+ * This function performs the matrix multiply-accumulate operation
+ * using the `hmma16816` instruction.
+ *
+ * @tparam N The number of row tiles.
+ * @tparam K The number of row tiles for the A matrix and column tiles for the B matrix.
+ * @tparam M The number of row tiles for the B matrix.
+ * @param[out] d The output rt_fl accumulator.
+ * @param[in] a The first input rt_bf matrix.
+ * @param[in] b The second input rt_bf matrix in column-major mode.
+ * @param[in] c The input rt_fl accumulator matrix.
+ */
+template
+__device__ static inline void mma_AtBt(rt_fl &d,
+                                const rt_bf &a,
+                                const rt_bf &b,
+                                const rt_fl &c) {
+    #pragma unroll
+    for(int n = 0; n < N; n++) {
+        #pragma unroll
+        for(int m = 0; m < M; m++) {
+            mma_AtBt_base(
+                d.tiles[n][m],
+                a.tiles[0][n],
+                b.tiles[m][0],
+                c.tiles[n][m]
+            );
+            #pragma unroll
+            for(int k = 1; k < K; k++) {
+                mma_AtBt_base(
+                    d.tiles[n][m],
+                    a.tiles[k][n],
+                    b.tiles[m][k],
+                    d.tiles[n][m]
+                );
+            }
+        }
+    }
+}
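+
+// Usage sketch (illustrative only, not part of this patch's tested surface):
+// assuming the register-tile aliases from types/, one K-step of a tile GEMM
+// accumulating C += A @ B might look roughly like:
+//
+//     kittens::rt_fl<1, 1> C;   // hypothetical 16x16 fp32 accumulator
+//     kittens::rt_bf<1, 4> A;   // 16x64 bf16 operand, row layout
+//     kittens::rt_bf<4, 1, kittens::ducks::rt_layout::col> B; // 64x16, col layout
+//     kittens::zero(C);
+//     kittens::mma_AB(C, A, B, C);   // C = A @ B + C
+
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/warp/register/tile/reductions.cuh b/triteia/csrc/flash_kittens/ops/warp/register/tile/reductions.cuh
new file mode 100644
index 0000000..5366dbf
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/register/tile/reductions.cuh
@@ -0,0 +1,455 @@
+/**
+ * @file
+ * @brief Reduction operations mapping tiles to vectors.
+ */
+
+#pragma once
+
+#include "../../../../common/common.cuh"
+#include "../../../../types/types.cuh"
+
+namespace kittens {
+
+/**
+ * @brief Perform a row-wise reduction on a matrix in row-major layout.
+ *
+ * This function template performs a parallel reduction across the rows of a matrix using a specified operation.
+ * It leverages warp shuffle functions for efficient intra-warp communication.
+ *
+ * @tparam op The operation to be applied for reduction.
+ * @tparam V The vector type for the row accumulator.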
+ * @tparam T The matrix type with row layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. + static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::col_vec_pack); // compatible layout + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = V::dtype; + + const int leader = threadIdx.x & 0x1C; // 11100 in binary + #pragma unroll + for(int i = 0; i < src.height; i++) { + dtype accum_top_row = op::template op(src.tiles[i][0].data[0], src.tiles[i][0].data[2]); + dtype accum_bottom_row = op::template op(src.tiles[i][0].data[1], src.tiles[i][0].data[3]); + #pragma unroll + for(int j = 1; j < src.width; j++) { + #pragma unroll + for(int k = 0; k < src.packed_per_tile; k+=2) { + accum_top_row = op::template op(accum_top_row, src.tiles[i][j].data[k+0]); + accum_bottom_row = op::template op(accum_bottom_row, src.tiles[i][j].data[k+1]); + } + } + dtype accum_packed; + accum_packed.x = op::template op::unpacked_type>(accum_top_row.x, accum_top_row.y); + accum_packed.y = op::template op::unpacked_type>(accum_bottom_row.x, accum_bottom_row.y); + + // Now we need to do a lil shuffle to make everyone happy. + + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2)); + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1)); + + accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader); + + if(reset) { + row_accum[i][0] = accum_packed; + } + else { + row_accum[i][0] = op::template op(src_accum[i][0], accum_packed); + } + } +} +/** + * @brief Perform a row-wise reduction on a matrix in column-major layout. + * + * This function template performs a parallel reduction across the rows of a matrix using a specified operation. + * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for column-major matrices. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type with column layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. 
+ static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::col_vec_pack); // compatible layout + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = V::dtype; + + const int leader = threadIdx.x & 0x3; // 00011 in binary + #pragma unroll + for(int i = 0; i < src.height; i++) { + dtype accum_top_rows = op::template op(src.tiles[i][0].data[0], src.tiles[i][0].data[1]); + dtype accum_bottom_rows = op::template op(src.tiles[i][0].data[2], src.tiles[i][0].data[3]); + #pragma unroll + for(int j = 1; j < src.width; j++) { + #pragma unroll + for(int k = 0; k < src.packed_per_tile/2; k++) { + accum_top_rows = op::template op(accum_top_rows, src.tiles[i][j].data[k+0]); + accum_bottom_rows = op::template op(accum_bottom_rows, src.tiles[i][j].data[k+2]); + } + } + + // Now we need to do a lil shuffle to make everyone happy. + + accum_top_rows = op::template op(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 16)); + accum_top_rows = op::template op(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 8)); + accum_top_rows = op::template op(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 4)); + + accum_bottom_rows = op::template op(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 16)); + accum_bottom_rows = op::template op(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 8)); + accum_bottom_rows = op::template op(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 4)); + + accum_top_rows = packed_shfl_sync(MASK_ALL, accum_top_rows, leader); + accum_bottom_rows = packed_shfl_sync(MASK_ALL, accum_bottom_rows, leader); + + if(reset) { + row_accum[i][0] = accum_top_rows; + row_accum[i][1] = accum_bottom_rows; + } + else { + row_accum[i][0] = op::template op(src_accum[i][0], accum_top_rows); + row_accum[i][1] = op::template op(src_accum[i][1], accum_bottom_rows); + } + } +} + +// Col reduction. +/** + * @brief Perform a column-wise reduction on a matrix in row-major layout. + * + * This function template performs a parallel reduction across the columns of a matrix using a specified operation. + * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for row-major matrices. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The vector type for the column accumulator. + * @tparam T The matrix type with row layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. 
+ static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::row_vec_pack); // compatible layout + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = V::dtype; + + const int leader = threadIdx.x & 0x3; // 00011 in binary + #pragma unroll + for(int j = 0; j < src.width; j++) { + dtype accum_left_cols = op::template op(src.tiles[0][j].data[0], src.tiles[0][j].data[1]); + dtype accum_right_cols = op::template op(src.tiles[0][j].data[2], src.tiles[0][j].data[3]); + #pragma unroll + for(int i = 1; i < src.height; i++) { + #pragma unroll + for(int k = 0; k < src.packed_per_tile/2; k++) { + accum_left_cols = op::template op(accum_left_cols, src.tiles[i][j].data[k+0]); + accum_right_cols = op::template op(accum_right_cols, src.tiles[i][j].data[k+2]); + } + } + + // Now we need to do a lil shuffle to make everyone happy. + + accum_left_cols = op::template op(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 16)); + accum_left_cols = op::template op(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 8)); + accum_left_cols = op::template op(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 4)); + + accum_right_cols = op::template op(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 16)); + accum_right_cols = op::template op(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 8)); + accum_right_cols = op::template op(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 4)); + + accum_left_cols = packed_shfl_sync(MASK_ALL, accum_left_cols, leader); + accum_right_cols = packed_shfl_sync(MASK_ALL, accum_right_cols, leader); + + if(reset) { + col_accum[j][0] = accum_left_cols; + col_accum[j][1] = accum_right_cols; + } + else { + col_accum[j][0] = op::template op(src_accum[j][0], accum_left_cols); + col_accum[j][1] = op::template op(src_accum[j][1], accum_right_cols); + } + } +} +/** + * @brief Perform a column-wise reduction on a matrix in column-major layout. + * + * This function template performs a parallel reduction across the columns of a matrix using a specified operation. + * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for column-major matrices. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The vector type for the column accumulator. + * @tparam T The matrix type with column layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. 
+ static_assert(std::is_same_v); // compatible type + static_assert(V::inner_dim == rt_base::row_vec_pack); // compatible layout + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = V::dtype; + const int leader = threadIdx.x & 0x1C; // 11100 in binary + #pragma unroll + for(int j = 0; j < src.width; j++) { // note now width is the outer loop + dtype accum_left_col = op::template op(src.tiles[0][j].data[0], src.tiles[0][j].data[2]); + dtype accum_right_col = op::template op(src.tiles[0][j].data[1], src.tiles[0][j].data[3]); + #pragma unroll + for(int i = 1; i < src.height; i++) { // and height is the inner loop + #pragma unroll + for(int k = 0; k < src.packed_per_tile; k+=2) { + accum_left_col = op::template op(accum_left_col, src.tiles[i][j].data[k+0]); + accum_right_col = op::template op(accum_right_col, src.tiles[i][j].data[k+1]); + } + } + dtype accum_packed; + accum_packed.x = op::template op::unpacked_type>(accum_left_col.x, accum_left_col.y); + accum_packed.y = op::template op::unpacked_type>(accum_right_col.x, accum_right_col.y); + + // Now we need to do a lil shuffle to make everyone happy. + + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2)); + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1)); + + accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader); + + if(reset) { + col_accum[j][0] = accum_packed; + } + else { + col_accum[j][0] = op::template op(src_accum[j][0], accum_packed); + } + } +} + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// two-operand row reductions. (Accumulate and REPLACE.) +/** + * @brief Store the maximum of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the minimum of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the sum of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_sum(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the product of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. 
+ */
+template
+__device__ static inline void row_prod(V &row_accum, const T &src) {
+    row_reduce(row_accum, src, row_accum);
+}
+// three-operand row reductions. (Accumulate ONTO.)
+/**
+ * @brief Store the maximum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
+ *
+ * @tparam V The vector type for the row accumulator.
+ * @tparam T The matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_max(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+/**
+ * @brief Store the minimum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
+ *
+ * @tparam V The vector type for the row accumulator.
+ * @tparam T The matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_min(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+/**
+ * @brief Store the sum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
+ *
+ * @tparam V The vector type for the row accumulator.
+ * @tparam T The matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_sum(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+/**
+ * @brief Store the product of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
+ *
+ * @tparam V The vector type for the row accumulator.
+ * @tparam T The matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_prod(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+
+// two-operand col reductions. (Accumulate and REPLACE.)
+
+/**
+ * @brief Store the maximum of each column of the src register tile in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_max(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+/**
+ * @brief Store the minimum of each column of the src register tile in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_min(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+/**
+ * @brief Store the sum of each column of the src register tile in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_sum(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+/**
+ * @brief Store the product of each column of the src register tile in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_prod(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+// three-operand col reductions. (Accumulate ONTO.)
+/**
+ * @brief Store the maximum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void col_max(V &col_accum, const T &src, const V &src_accum) {
+    col_reduce(col_accum, src, src_accum);
+}
+/**
+ * @brief Store the minimum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void col_min(V &col_accum, const T &src, const V &src_accum) {
+    col_reduce(col_accum, src, src_accum);
+}
+/**
+ * @brief Store the sum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void col_sum(V &col_accum, const T &src, const V &src_accum) {
+    col_reduce(col_accum, src, src_accum);
+}
+/**
+ * @brief Store the product of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
+ *
+ * @tparam V The vector type for the column accumulator.
+ * @tparam T The matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_prod(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/tile/tile.cuh b/triteia/csrc/flash_kittens/ops/warp/register/tile/tile.cuh new file mode 100644 index 0000000..f154dab --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/tile/tile.cuh @@ -0,0 +1,11 @@ +/** + * @file + * @brief An aggregate header for warp operations on register tiles. + */ + +#pragma once + +#include "conversions.cuh" +#include "maps.cuh" +#include "reductions.cuh" +#include "mma.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/vec/conversions.cuh b/triteia/csrc/flash_kittens/ops/warp/register/vec/conversions.cuh new file mode 100644 index 0000000..4f6f564 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/vec/conversions.cuh @@ -0,0 +1,104 @@ +/** + * @file + * @brief Conversions on vectors stored in registers. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +namespace detail { + +// i am not smart enough to figure out these indices without these helpers :/ +// again, blame nvidia for these stupid, stupid layouts +__device__ static inline int row_from_indices_dim2(int laneid, int inner_dim, int x_or_y) { + return 8*inner_dim + (laneid%4)*2 + x_or_y; +} +__device__ static inline int row_from_indices_dim1(int laneid, int x_or_y) { + return 8*x_or_y + (laneid/4); +} +__device__ static inline int canonical_src_lane_dim2(int row) { + return (row/2)%4 + 4*(row%2); // draw even rows from 0...3 and odds from 4...7 +} +__device__ static inline int canonical_src_lane_dim1(int row) { + return (row*4)%32; +} + +} + +/** + * @brief Copies data from one register vector to another. + * + * @tparam RV1 The type of the destination register vector. + * @tparam RV2 The type of the source register vector. + * @param dst[out] The destination register vector. + * @param src[in] The source register vector to copy from. + */ +template +__device__ static inline void copy(RV1 &dst, const RV2 &src) { + static_assert(RV1::outer_dim == RV2::outer_dim, "Outer dimensions of the register vectors must be the same."); + using D1 = RV1::dtype; + using D2 = RV2::dtype; + if constexpr (RV1::inner_dim == RV2::inner_dim) { + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + #pragma unroll + for(int j = 0; j < RV1::inner_dim; j++) { + dst[i][j] = base_types::convertor::convert(src[i][j]); + } + } + } + // Inner dimensions are not the same, this is really a layout conversion. + else if constexpr (RV1::inner_dim == 1 && RV2::inner_dim == 2) { + // Convert from an unaligned vector layout to an aligned vector layout. + int laneid = kittens::laneid(); + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + dst[i][0].x = packed_shfl_sync( + kittens::MASK_ALL, + laneid < 4 ? src[i][0].x : src[i][0].y, // mirrors canonical_src_lane_dim2 + detail::canonical_src_lane_dim2(detail::row_from_indices_dim1(laneid, 0)) + ); + dst[i][0].y = packed_shfl_sync( + kittens::MASK_ALL, + laneid < 4 ? 
src[i][1].x : src[i][1].y, // mirrors canonical_src_lane_dim2 + detail::canonical_src_lane_dim2(detail::row_from_indices_dim1(laneid, 1)) + ); + } + } + else if constexpr (RV1::inner_dim == 2 && RV2::inner_dim == 1) { + // Convert from an aligned vector layout to an unaligned vector layout. + int laneid = kittens::laneid(); + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + dst[i][0].x = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].x, // first 8 rows + detail::canonical_src_lane_dim1(detail::row_from_indices_dim2(laneid, 0, 0)) + ); + dst[i][0].y = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].x, // first 8 rows + detail::canonical_src_lane_dim1(detail::row_from_indices_dim2(laneid, 0, 1)) + ); + dst[i][1].x = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].y, // last 8 rows + detail::canonical_src_lane_dim1(detail::row_from_indices_dim2(laneid, 1, 0)) + ); + dst[i][1].y = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].y, // last 8 rows + detail::canonical_src_lane_dim1(detail::row_from_indices_dim2(laneid, 1, 1)) + ); + } + } + else { + static_assert(RV1::inner_dim == RV2::inner_dim, "Something has gone deeply wrong with how register vectors were instantiated."); + } +} + +} // namespace kittens \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/vec/maps.cuh b/triteia/csrc/flash_kittens/ops/warp/register/vec/maps.cuh new file mode 100644 index 0000000..7a5f9f3 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/vec/maps.cuh @@ -0,0 +1,270 @@ +/** + * @file + * @brief Maps on vectors stored in registers. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/* ---------- Vector Maps ---------- */ + +/** + * @brief Perform a unary operation on a vector. + * + * @tparam op The unary operation to perform. + * @tparam T The type of the vector. + * @param dst[out] The destination vector where the result is stored. + * @param src[in] The source vector to perform the operation on. + */ +template +__device__ static inline void unary_op(T &dst, const T &src) { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + #pragma unroll + for(int j = 0; j < dst.inner_dim; j++) { + dst[i][j] = op::template op(src[i][j]); + } + } +} +/** + * @brief Perform a binary operation on two vectors. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vectors. + * @param dst[out] The destination vector where the result is stored. + * @param lhs[in] The left-hand side vector for the operation. + * @param rhs[in] The right-hand side vector for the operation. + */ +template +__device__ static inline void bin_op(T &dst, const T &lhs, const T &rhs) { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + #pragma unroll + for(int j = 0; j < dst.inner_dim; j++) { + dst[i][j] = op::template op(lhs[i][j], rhs[i][j]); + } + } +} +/** + * @brief Perform a binary operation on a vector and a scalar. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vector. + * @param dst[out] The destination vector where the result is stored. + * @param src[in] The source vector for the operation. + * @param param[in] The scalar parameter for the operation. 
+ */
+template
+__device__ static inline void bin_op(T &dst, const T &src, const typename T::dtype &param) {
+    #pragma unroll
+    for(int i = 0; i < dst.outer_dim; i++) {
+        #pragma unroll
+        for(int j = 0; j < dst.inner_dim; j++) {
+            dst[i][j] = op::template op(src[i][j], param);
+        }
+    }
+}
+/**
+ * @brief Perform a binary operation on a vector and an unpacked scalar.
+ *
+ * @tparam op The binary operation to perform.
+ * @tparam T The type of the vector.
+ * @param dst[out] The destination vector where the result is stored.
+ * @param src[in] The source vector for the operation.
+ * @param param[in] The unpacked scalar parameter for the operation.
+ */
+template
+__device__ static inline void bin_op(T &dst, const T &src, const typename base_types::packing::unpacked_type &param) {
+    bin_op(dst, src, base_types::packing::pack(param));
+}
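+
+// Usage sketch (illustrative only, not part of this patch's tested surface):
+// the wrappers below bottom out in these primitives; e.g. a scale-and-shift
+// y = x * alpha + beta over a hypothetical register vector:
+//
+//     kittens::rt_fl<1, 4>::col_vec x, y;
+//     float alpha = 2.0f, beta = 1.0f;
+//     kittens::mul(y, x, alpha);   // bin_op with base_ops::mul and a scalar
+//     kittens::add(y, y, beta);    // bin_op with base_ops::sum and a scalar
+
+/* ---------- WRAPPERS FOR PRETTINESS ---------- */
+
+// ---- const ops ----
+
+/**
+ * @brief Sets all elements of a register vector to zero.
+ *
+ * @tparam T Register vector type.
+ * @param dst[out] Destination vector to be set to zero.
+ */
+template
+__device__ static inline void zero(T &dst) {
+    unary_op(dst, dst);
+}
+/**
+ * @brief Sets all elements of a register vector to one.
+ *
+ * @tparam T Register vector type.
+ * @param dst[out] Destination vector to be set to one.
+ */
+template
+__device__ static inline void one(T &dst) {
+    unary_op(dst, dst);
+}
+/**
+ * @brief Sets all elements of a register vector to positive infinity.
+ *
+ * @tparam T Register vector type.
+ * @param dst[out] Destination vector to be set to positive infinity.
+ */
+template
+__device__ static inline void pos_infty(T &dst) {
+    unary_op(dst, dst);
+}
+/**
+ * @brief Sets all elements of a register vector to negative infinity.
+ *
+ * @tparam T Register vector type.
+ * @param dst[out] Destination vector to be set to negative infinity.
+ */
+template
+__device__ static inline void neg_infty(T &dst) {
+    unary_op(dst, dst);
+}
+
+// ---- unary ops ----
+
+/**
+ * @brief Copies the elements from one register vector to another.
+ *
+ * @tparam T Register vector type.
+ * @tparam U Type of the source vector.
+ * @param dst[out] Destination vector where the elements will be copied to.
+ * @param src[in] Source vector to copy the elements from.
+ */
+template
+__device__ static inline void copy(T &dst, const U &src) {
+    bin_op(dst, dst, src); // the second arg is ignored here.
+}
+/**
+ * @brief Applies the exponential function element-wise to a register vector.
+ *
+ * @tparam T Register vector type.
+ * @param dst[out] Destination vector where the exponential values will be stored.
+ * @param src[in] Source vector to apply the exponential function to.
+ */
+template
+__device__ static inline void exp(T &dst, const T &src) {
+    unary_op(dst, src);
+}
+/**
+ * @brief Applies the natural logarithm function element-wise to a register vector.
+ *
+ * @tparam T Register vector type.
+ * @param dst[out] Destination vector where the logarithm values will be stored.
+ * @param src[in] Source vector to apply the natural logarithm function to.
+ */
+template
+__device__ static inline void log(T &dst, const T &src) {
+    unary_op(dst, src);
+}
+/**
+ * @brief Applies the absolute value function element-wise to a register vector.
+ *
+ * @tparam T Register vector type.
+ * @param dst[out] Destination vector where the absolute values will be stored.
+ * @param src[in] Source vector to apply the absolute value function to.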
+ */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the rectified linear unit (ReLU) function element-wise to a register vector. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector where the ReLU values will be stored. + * @param src[in] Source vector to apply the ReLU function to. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_op(dst, src); +} + +// ---- binary ops ---- + +/** + * @brief Computes the element-wise maximum of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the maximum values will be stored. + * @param lhs[in] First vector for the maximum operation. + * @param rhs[in] Second vector for the maximum operation. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise minimum of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the minimum values will be stored. + * @param lhs[in] First vector for the minimum operation. + * @param rhs[in] Second vector for the minimum operation. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise sum of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the sum values will be stored. + * @param lhs[in] First vector for the sum operation. + * @param rhs[in] Second vector for the sum operation. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise difference of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the difference values will be stored. + * @param lhs[in] First vector for the difference operation. + * @param rhs[in] Second vector for the difference operation. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise product of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the product values will be stored. + * @param lhs[in] First vector for the product operation. + * @param rhs[in] Second vector for the product operation. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise division of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the division values will be stored. + * @param lhs[in] First vector for the division operation. + * @param rhs[in] Second vector for the division operation. 
+ */
+template
+__device__ static inline void div(T &dst, const T &lhs, const U &rhs) {
+    bin_op(dst, lhs, rhs);
+}
+
+}
diff --git a/triteia/csrc/flash_kittens/ops/warp/register/vec/reductions.cuh b/triteia/csrc/flash_kittens/ops/warp/register/vec/reductions.cuh
new file mode 100644
index 0000000..a9fe8dc
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/register/vec/reductions.cuh
@@ -0,0 +1,180 @@
+/**
+ * @file
+ * @brief Reductions on vectors stored in registers.
+ */
+
+#pragma once
+
+#include "../../../../common/common.cuh"
+#include "../../../../types/types.cuh"
+
+namespace kittens {
+
+/* ---------- Vector Reductions ---------- */
+
+/**
+ * @brief Performs a reduction operation on elements of a register vector within a warp.
+ *
+ * This function applies a specified operation to reduce the elements of a register vector `src` to a single value.
+ * The result is stored in `dst_accum`. If the `reset` parameter is false, the reduction also incorporates the initial value `src_accum`.
+ * The reduction operation is performed in a warp-wide context, ensuring synchronization between threads in the warp.
+ *
+ * @tparam op The operation to perform on the elements. Must provide a static `op` method.
+ * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
+ * @tparam reset A boolean flag indicating whether to ignore `src_accum` and start the reduction fresh.
+ * @param[out] dst_accum The result of the reduction operation.
+ * @param[in] src The register vector to reduce.
+ * @param[in] src_accum The initial value to include in the reduction if `reset` is false.
+ */
+template
+__device__ static inline void reduce(
+    typename base_types::packing::unpacked_type &dst_accum,
+    const RV &src,
+    const typename base_types::packing::unpacked_type &src_accum) {
+    using T = base_types::packing::unpacked_type;
+    int laneid = kittens::laneid();
+    if constexpr (RV::inner_dim == 1) {
+        T accum = op::template op(src[0][0].x, src[0][0].y);
+        #pragma unroll
+        for(int i = 1; i < src.outer_dim; i++) {
+            accum = op::template op(accum, src[i][0].x);
+            accum = op::template op(accum, src[i][0].y);
+        }
+        // we've now reduced everything into 8 distinct values, replicated across lanes x, x+1, x+2, x+3 for x≡0(mod4)
+        accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16));
+        accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8));
+        accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4));
+        // we've now reduced everything into 1 distinct value, replicated across lanes 0, 1, 2, 3
+        if constexpr (!reset) accum = op::template op(accum, src_accum);
+        // final result has now been achieved (incorporating src_accum if necessary), finally broadcast back to all threads.
+        dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0);
+    }
+    else if constexpr (RV::inner_dim == 2) {
+        T accum = op::template op(src[0][0].x, src[0][0].y);
+        accum = op::template op(accum, src[0][1].x);
+        accum = op::template op(accum, src[0][1].y);
+        #pragma unroll
+        for(int i = 1; i < src.outer_dim; i++) {
+            // it is possible that shfl_sync's would be faster but I doubt it, replication is likely better. Certainly simpler.
+            accum = op::template op(accum, src[i][0].x);
+            accum = op::template op(accum, src[i][0].y);
+            accum = op::template op(accum, src[i][1].x);
+            accum = op::template op(accum, src[i][1].y);
+        }
+        // we've now reduced everything into 4 distinct values, replicated across lanes x, x+4, x+8, ..., x+28 for x<4
+        accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2));
+        accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1));
+        // we've now reduced everything into 1 distinct value, replicated across lanes 0, 4, 8, 12, ..., 28
+        if constexpr (!reset) accum = op::template op(accum, src_accum);
+        // final result has now been achieved (incorporating src_accum if necessary), finally broadcast back to all threads from lane 0
+        dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0);
+    }
+    else {
+        static_assert(RV::inner_dim==1 || RV::inner_dim==2, "RV's can only have an inner dimension of 1 or 2!");
+    }
+}
+
+
+/**
+ * @brief Finds the maximum element in a register vector.
+ *
+ * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
+ * @param[out] max_val The maximum value found in the vector.
+ * @param[in] src The register vector to find the maximum in.
+ */
+template
+__device__ static inline void max(typename base_types::packing::unpacked_type &max_val, const RV &src) {
+    reduce(max_val, src, max_val);
+}
+
+/**
+ * @brief Finds the minimum element in a register vector.
+ *
+ * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
+ * @param[out] min_val The minimum value found in the vector.
+ * @param[in] src The register vector to find the minimum in.
+ */
+template
+__device__ static inline void min(typename base_types::packing::unpacked_type &min_val, const RV &src) {
+    reduce(min_val, src, min_val);
+}
+
+/**
+ * @brief Calculates the sum of elements in a register vector.
+ *
+ * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
+ * @param[out] sum_val The sum of the values in the vector.
+ * @param[in] src The register vector to sum.
+ */
+template
+__device__ static inline void sum(typename base_types::packing::unpacked_type &sum_val, const RV &src) {
+    reduce(sum_val, src, sum_val);
+}
+
+/**
+ * @brief Calculates the product of elements in a register vector.
+ *
+ * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
+ * @param[out] prod_val The product of the values in the vector.
+ * @param[in] src The register vector to multiply.
+ */
+template
+__device__ static inline void prod(typename base_types::packing::unpacked_type &prod_val, const RV &src) {
+    reduce(prod_val, src, prod_val);
+}
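+
+// Usage sketch (illustrative only, not part of this patch's tested surface):
+// collapsing a hypothetical register vector into a warp-uniform scalar:
+//
+//     kittens::rt_fl<1, 4>::col_vec v;
+//     float total;
+//     kittens::sum(total, v);   // every lane ends up with the same value
+
+// Three operand versions.
+
+/**
+ * @brief Finds the maximum element in a register vector and accumulates it with src_accum.
+ *
+ * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
+ * @param[out] max_val The maximum value found in the vector, accumulated with src_accum.
+ * @param[in] src The register vector to find the maximum in.
+ * @param[in] src_accum The initial value to accumulate with the maximum value found.
+ */
+template
+__device__ static inline void max(typename base_types::packing::unpacked_type &max_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) {
+    reduce(max_val, src, src_accum);
+}
+
+/**
+ * @brief Finds the minimum element in a register vector and accumulates it with src_accum.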
+ * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] min_val The minimum value found in the vector, accumulated with src_accum. + * @param[in] src The register vector to find the minimum in. + * @param[in] src_accum The initial value to accumulate with the minimum value found. + */ +template +__device__ static inline void min(typename base_types::packing::unpacked_type &min_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + reduce(min_val, src, src_accum); +} + +/** + * @brief Calculates the sum of elements in a register vector and accumulates it with src_accum. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] sum_val The sum of the values in the vector, accumulated with src_accum. + * @param[in] src The register vector to sum. + * @param[in] src_accum The initial value to accumulate with the sum of the vector. + */ +template +__device__ static inline void sum(typename base_types::packing::unpacked_type &sum_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + reduce(sum_val, src, src_accum); +} + +/** + * @brief Calculates the product of elements in a register vector and accumulates it with src_accum. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] prod_val The product of the values in the vector, accumulated with src_accum. + * @param[in] src The register vector to multiply. + * @param[in] src_accum The initial value to accumulate with the product of the vector. + */ +template +__device__ static inline void prod(typename base_types::packing::unpacked_type &prod_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + reduce(prod_val, src, src_accum); +} + +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/register/vec/vec.cuh b/triteia/csrc/flash_kittens/ops/warp/register/vec/vec.cuh new file mode 100644 index 0000000..279f307 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/register/vec/vec.cuh @@ -0,0 +1,10 @@ +/** + * @file + * @brief An aggregate header for warp operations on register vectors. + */ + +#pragma once + +#include "conversions.cuh" +#include "maps.cuh" +#include "reductions.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/shared.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/shared.cuh new file mode 100644 index 0000000..0b7b372 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/shared/shared.cuh @@ -0,0 +1,9 @@ +/** + * @file + * @brief An aggregate header of warp operations on data in shared memory + */ + +#pragma once + +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/tile/conversions.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/tile/conversions.cuh new file mode 100644 index 0000000..de1ae73 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/shared/tile/conversions.cuh @@ -0,0 +1,60 @@ +/** + * @file + * @brief Conversions between shared tile types. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/* ---------- COPIES ---------- */ + +/** + * @brief Copies data from one shared memory tile to another, potentially with different data types and layouts. + * + * @tparam T The data type of the destination tile. 
+ * @tparam U The data type of the source tile. + * @tparam _height The height of the tile. + * @tparam _width The width of the tile. + * @tparam L1 The layout of the destination tile. + * @tparam L2 The layout of the source tile. + * @param[out] dst The destination tile. + * @param[in] src The source tile. + */ +template +__device__ static inline void copy(st &dst, const st &src) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i+=kittens::WARP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = base_types::convertor::convert(src[{row, col}]); + } +} + +/* ---------- SUBTILE ---------- */ + +/** +* @brief Returns a reference to a subtile of the given shared tile. +* +* @tparam subtile_height The height of the subtile. +* @tparam subtile_width The width of the subtile. +* @tparam ST The type of the input tile, which must satisfy the ducks::st::all concept. +* @param src The input tile. +* @param row_idx The row index of the subtile, in units of subtile_height*16 elements. +* @param col_idx The col index of the subtile, in units of subtile_width*16 elements. +* @return A reference to the subtile. +* +* @note The subtile {height, width} must evenly divide the tile {height, width}. +*/ +template +__device__ inline typename ST::subtile subtile_inplace(ST &src, int row_idx, int col_idx) { + static_assert(ST::height % subtile_height == 0); + static_assert(ST::width % subtile_width == 0); + return typename ST::subtile( + &src[0], subtile_height*16*row_idx, subtile_width*16*col_idx + ); +} + +} // namespace kittens \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/tile/maps.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/tile/maps.cuh new file mode 100644 index 0000000..4b72c11 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/shared/tile/maps.cuh @@ -0,0 +1,445 @@ +/** + * @file + * @brief Warp-scope maps on shared tiles. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/* ---------- Uniform tile maps (independent of layout) ---------- */ + +/** + * @brief Performs a uniform unary operation on a tile. + * + * This function applies a given unary operation to each element of the source tile and stores the result in the destination tile. + * The operation is applied independently to each element, without considering its position or the values of neighboring elements. + * + * @tparam op The unary operation to be applied. Must be specialized to support operation on the data type of T. + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the unary operation is applied. + */ +template // T2, w, h can be inferred from dst as long as op is specialized +__device__ static inline void unary_map(T &dst, const T &src) { + #pragma unroll + for(int i = kittens::laneid(); i < dst.num_elements; i += WARP_THREADS) { + dst.data[i] = op::template op(src.data[i]); + } +} + +/** + * @brief Performs a uniform binary operation on a tile with a scalar parameter. + * + * This function applies a given binary operation to each element of the source tile and a scalar parameter, then stores the result in the destination tile. + * The operation is applied independently to each element, treating the scalar parameter as the second operand for each operation. + * + * @tparam op The binary operation to be applied. 
Must be specialized to support operation on the data type of T and the scalar parameter.
+ * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
+ * @param[out] dst The destination tile where the results are stored.
+ * @param[in] src The source tile to which the binary operation is applied.
+ * @param[in] param The scalar parameter to be used as the second operand in the binary operation.
+ */
+template
+__device__ static inline void bin_map(T &dst, const T &src, const typename T::dtype &param) {
+    #pragma unroll
+    for(int i = kittens::laneid(); i < dst.num_elements; i += WARP_THREADS) {
+        dst.data[i] = op::template op(src.data[i], param);
+    }
+}
+
+/**
+ * @brief Performs a uniform binary operation on two tiles.
+ *
+ * This function applies a given binary operation to corresponding elements of two source tiles and stores the result in the destination tile.
+ * The operation is applied independently to each pair of elements, without considering their positions or the values of neighboring elements.
+ *
+ * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T.
+ * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
+ * @param[out] dst The destination tile where the results are stored.
+ * @param[in] lhs The first source tile to which the binary operation is applied.
+ * @param[in] rhs The second source tile to which the binary operation is applied.
+ */
+template
+__device__ static inline void bin_map(T &dst, const T &lhs, const T &rhs) {
+    #pragma unroll
+    for(int i = kittens::laneid(); i < dst.num_elements; i += WARP_THREADS) {
+        dst.data[i] = op::template op(lhs.data[i], rhs.data[i]);
+    }
+}
+
+/**
+ * @brief Performs a row-wise binary operation on a tile with a vector.
+ *
+ * This function applies a given binary operation to each row of the source tile and the corresponding element of the source vector,
+ * then stores the result in the destination tile. The operation is applied independently to each row, using the vector element as
+ * the second operand for each element in the row.
+ *
+ * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements.
+ * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
+ * @tparam V The type of the vector. Must have the same data type as T.
+ * @param[out] dst The destination tile where the results are stored.
+ * @param[in] src The source tile to which the binary operation is applied.
+ * @param[in] vec The source vector containing the second operand for each row operation.
+ */
+template
+__device__ static inline void row_map(T &dst, const T &src, const V &vec) {
+    static_assert(std::is_same::value, "Tile and vector must have the same data type");
+    static_assert(V::length == T::rows, "Vector length must match the number of rows in the tile");
+    #pragma unroll
+    for(int i = kittens::laneid(); i < dst.num_elements; i += WARP_THREADS) {
+        int row = i/dst.cols, col = i%dst.cols;
+        dst[{row, col}] = op::template op(src[{row, col}], vec[row]);
+    }
+}
+
+/**
+ * @brief Performs a column-wise binary operation on a tile with a vector.
+ *
+ * This function applies a given binary operation to each column of the source tile and the corresponding element of the source vector,
+ * then stores the result in the destination tile.
The operation is applied independently to each column, using the vector element as + * the second operand for each element in the column. + * + * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements. + * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept. + * @tparam V The type of the vector. Must have the same data type as T. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the binary operation is applied. + * @param[in] vec The source vector containing the second operand for each column operation. + */ +template +__device__ static inline void col_map(T &dst, const T &src, const V &vec) { + static_assert(std::is_same::value, "Tile and vector must have the same data type"); + static_assert(V::length == T::cols, "Vector length must match the number of columns in the tile"); + #pragma unroll + for(int i = kittens::laneid(); i < dst.num_elements; i += WARP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = op::template op(src[{row, col}], vec[col]); + } +} + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// All of the annoying qualifiers *should* be automatically inferred during compile-time. +// So, syntax should just be kittens::add_row(tile, colvec); + +// const maps +/** + * @brief Sets all elements of the destination tile to zero. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void zero(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of the destination tile to one. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void one(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of the destination tile to positive infinity. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void pos_infty(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of the destination tile to negative infinity. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile. + */ +template +__device__ static inline void neg_infty(T &dst) { + unary_map(dst, dst); +} + +// unary maps +/** + * @brief Applies the exponential function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the exponential function is applied. + */ +template +__device__ static inline void exp(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the natural logarithm function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the natural logarithm function is applied. 
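 *
 * A usage sketch, not part of the original header (the tile here is an
 * illustrative, already-allocated shared tile):
 * @code
 * kittens::abs(tile, tile); // make entries non-negative first
 * kittens::log(tile, tile); // then take the natural log in place
 * @endcode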
+ */ +template +__device__ static inline void log(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the absolute function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the absolute function is applied. + */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Applies the rectified linear unit function to each element of the source tile and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source tile to which the rectified linear unit function is applied. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_map(dst, src); +} +/** + * @brief Copies the elements of the source tile to the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] src The source data to be copied. + */ +template +__device__ static inline void copy(T &dst, const U &src) { + bin_map(dst, src); +} + +// uniform binary maps +/** + * @brief Finds the maximum of each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Finds the minimum of each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Adds each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Subtracts each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. 
Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Multiplies each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Divides each pair of corresponding elements in the two source tiles and stores the result in the destination tile. + * + * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept. + * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile. + * @param[out] dst The destination tile where the results are stored. + * @param[in] lhs The first source tile. + * @param[in] rhs The second source data. + */ +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +// Row and col maps + +/** + * @brief Adds row values to each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the addition on. + * @param row_values[in] Column vector containing values to add to each row. + */ +template +__device__ static inline void add_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Subtracts row values from each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the subtraction on. + * @param row_values[in] Column vector containing values to subtract from each row. + */ +template +__device__ static inline void sub_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Multiplies each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the multiplication on. + * @param row_values[in] Column vector containing values to multiply each row by. + */ +template +__device__ static inline void mul_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} +/** + * @brief Divides each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the division on. + * @param row_values[in] Column vector containing values to divide each row by. 
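 *
 * Usage sketch (names are illustrative and assume a pre-allocated shared tile
 * `att` plus a shared column vector `row_sums` whose length matches att's rows):
 * @code
 * kittens::row_sum(row_sums, att);      // one total per row
 * kittens::div_row(att, att, row_sums); // softmax-style row normalization
 * @endcode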
+ */
+template
+__device__ static inline void div_row(T &dst, const T &src, const V &row_values) {
+    row_map(dst, src, row_values);
+}
+/**
+ * @brief Broadcast a vector into a tile's rows.
+ *
+ * @tparam T Tile type.
+ * @tparam V Column vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param row_values[in] Column vector containing values to broadcast into rows.
+ */
+template
+__device__ static inline void broadcast_row(T &dst, const V &row_values) {
+    row_map(dst, dst, row_values);
+}
+
+
+// col maps
+/**
+ * @brief Adds column values to each column of a tile.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the addition on.
+ * @param col_values[in] Row vector containing values to add to each column.
+ */
+template
+__device__ static inline void add_col(T &dst, const T &src, const V &col_values) {
+    col_map(dst, src, col_values);
+}
+/**
+ * @brief Subtracts column values from each column of a tile.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the subtraction on.
+ * @param col_values[in] Row vector containing values to subtract from each column.
+ */
+template
+__device__ static inline void sub_col(T &dst, const T &src, const V &col_values) {
+    col_map(dst, src, col_values);
+}
+/**
+ * @brief Multiplies each column of a tile by column values.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the multiplication on.
+ * @param col_values[in] Row vector containing values to multiply each column by.
+ */
+template
+__device__ static inline void mul_col(T &dst, const T &src, const V &col_values) {
+    col_map(dst, src, col_values);
+}
+/**
+ * @brief Divides each column of a tile by column values.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param src[in] Source tile to apply the division on.
+ * @param col_values[in] Row vector containing values to divide each column by.
+ */
+template
+__device__ static inline void div_col(T &dst, const T &src, const V &col_values) {
+    col_map(dst, src, col_values);
+}
+/**
+ * @brief Broadcast a vector into a tile's columns.
+ *
+ * @tparam T Tile type.
+ * @tparam V Row vector type.
+ * @param dst[out] Destination tile where the result is stored.
+ * @param col_values[in] Row vector containing values to broadcast into columns.
+ */
+template
+__device__ static inline void broadcast_col(T &dst, const V &col_values) {
+    col_map(dst, dst, col_values);
+}
+
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/tile/reductions.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/tile/reductions.cuh
new file mode 100644
index 0000000..9a57aae
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/shared/tile/reductions.cuh
@@ -0,0 +1,277 @@
+/**
+ * @file
+ * @brief Warp-scope reductions on shared tiles.
+ */
+
+#pragma once
+
+#include "../../../../common/common.cuh"
+#include "../../../../types/types.cuh"
+
+namespace kittens {
+
+/**
+ * Performs row-wise reduction on a matrix using a specified operation.
+ *
+ * @tparam op The operation to be applied for reduction.
+ * @tparam V The shared vector type for the row accumulator.
+ * @tparam T The shared matrix type with row layout. + * @param row_accum The accumulator where the result of the reduction is stored. + * @param src The source matrix on which to perform the reduction. + * @param src_accum The initial value of the accumulator, used when reset is false. + * @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + */ +template +__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) { + using dtype = typename V::dtype; + for (int row = kittens::laneid(); row < src.rows; row += kittens::WARP_THREADS) { + dtype accum = src[{row, 0}]; + #pragma unroll + for (int col = 1; col < src.cols; col++) { + accum = op::template op(accum, src[{row, col}]); + } + if (reset) { + row_accum[row] = accum; + } else { + row_accum[row] = op::template op(src_accum[row], accum); + } + } + __syncwarp(); +} + +/** + * Performs column-wise reduction on a matrix using a specified operation. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The shared vector type for the column accumulator. + * @tparam T The shared matrix type with column layout. + * @param col_accum The accumulator where the result of the reduction is stored. + * @param src The source matrix on which to perform the reduction. + * @param src_accum The initial value of the accumulator, used when reset is false. + * @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + */ +template +__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) { + using dtype = typename V::dtype; + for (int col = kittens::laneid(); col < src.cols; col += kittens::WARP_THREADS) { + dtype accum = src[{0, col}]; + #pragma unroll + for (int row = 1; row < src.rows; row++) { + accum = op::template op(accum, src[{row, col}]); + } + if (reset) { + col_accum[col] = accum; + } else { + col_accum[col] = op::template op(src_accum[col], accum); + } + } + __syncwarp(); +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +/** + * @brief Store the maximum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the minimum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the sum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. 
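 *
 * Worked example of the effect, for an illustrative 2x2 tile:
 * @code
 * // src = [[1, 2],    row_accum[0] = 1 + 2 = 3
 * //        [3, 4]]    row_accum[1] = 3 + 4 = 7
 * kittens::row_sum(row_accum, src);
 * @endcode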
+ */
+template
+__device__ static inline void row_sum(V &row_accum, const T &src) {
+    row_reduce(row_accum, src, row_accum);
+}
+/**
+ * @brief Store the product of each row of the src shared matrix in the row_accum shared vector.
+ *
+ * @tparam V The shared vector type for the row accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void row_prod(V &row_accum, const T &src) {
+    row_reduce(row_accum, src, row_accum);
+}
+
+/**
+ * @brief Store the maximum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
+ *
+ * @tparam V The shared vector type for the row accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_max(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+/**
+ * @brief Store the minimum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
+ *
+ * @tparam V The shared vector type for the row accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_min(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+/**
+ * @brief Store the sum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
+ *
+ * @tparam V The shared vector type for the row accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_sum(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+/**
+ * @brief Store the product of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
+ *
+ * @tparam V The shared vector type for the row accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] row_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void row_prod(V &row_accum, const T &src, const V &src_accum) {
+    row_reduce(row_accum, src, src_accum);
+}
+
+/**
+ * @brief Store the maximum of each column of the src shared matrix in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_max(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+/**
+ * @brief Store the minimum of each column of the src shared matrix in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_min(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+/**
+ * @brief Store the sum of each column of the src shared matrix in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_sum(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+/**
+ * @brief Store the product of each column of the src shared matrix in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ */
+template
+__device__ static inline void col_prod(V &col_accum, const T &src) {
+    col_reduce(col_accum, src, col_accum);
+}
+
+/**
+ * @brief Store the maximum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void col_max(V &col_accum, const T &src, const V &src_accum) {
+    col_reduce(col_accum, src, src_accum);
+}
+/**
+ * @brief Store the minimum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void col_min(V &col_accum, const T &src, const V &src_accum) {
+    col_reduce(col_accum, src, src_accum);
+}
+/**
+ * @brief Store the sum of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void col_sum(V &col_accum, const T &src, const V &src_accum) {
+    col_reduce(col_accum, src, src_accum);
+}
+/**
+ * @brief Store the product of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector.
+ *
+ * @tparam V The shared vector type for the column accumulator.
+ * @tparam T The shared matrix type.
+ * @param[out] col_accum The accumulator where the result of the reduction is stored.
+ * @param[in] src The source matrix on which to perform the reduction.
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
+ */
+template
+__device__ static inline void col_prod(V &col_accum, const T &src, const V &src_accum) {
+    col_reduce(col_accum, src, src_accum);
+}
+
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/tile/tile.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/tile/tile.cuh
new file mode 100644
index 0000000..0ba0bc0
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/shared/tile/tile.cuh
@@ -0,0 +1,10 @@
+/**
+ * @file
+ * @brief An aggregate header for warp operations on shared tiles.
+ */
+
+#pragma once
+
+#include "conversions.cuh"
+#include "maps.cuh"
+#include "reductions.cuh"
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/vec/conversions.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/vec/conversions.cuh
new file mode 100644
index 0000000..737cc78
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/shared/vec/conversions.cuh
@@ -0,0 +1,55 @@
+/**
+ * @file
+ * @brief Warp-scope conversions on shared vectors.
+ */
+
+#pragma once
+
+#include "../../../../common/common.cuh"
+#include "../../../../types/types.cuh"
+
+
+namespace kittens {
+
+/**
+ * @brief Copies data from one shared vector to another, converting data types if necessary.
+ *
+ * This function copies data from the source shared vector `src` to the destination shared vector `dst`.
+ * If the data types of `src` and `dst` are the same, it performs a direct memory copy. Otherwise, it
+ * converts each element from the source data type to the destination data type using the appropriate
+ * converter before copying.
+ *
+ * @tparam SV1 The type of the destination shared vector, must satisfy the ducks::sv::all concept.
+ * @tparam SV2 The type of the source shared vector, must satisfy the ducks::sv::all concept.
+ * @param[out] dst The destination shared vector.
+ * @param[in] src The source shared vector.
+ * @note The lengths of `src` and `dst` must be equal. This is enforced at compile time.
+ */
+template
+__device__ static inline void copy(SV1 &dst, const SV2 &src) {
+    static_assert(dst.length == src.length, "Source and destination vectors must have the same length.");
+    #pragma unroll
+    for(int i = kittens::laneid(); i < dst.length; i+=WARP_THREADS) {
+        dst[i] = base_types::convertor::convert(src[i]);
+    }
+}
+
+/* ---------- SUBVEC ---------- */
+
+/**
+* @brief Returns a reference to a subvec of a given shared vector.
+*
+* @tparam subvec_tiles The length, in subtiles, of the subvec.
+* @tparam SV The type of the input vector, which must satisfy the ducks::sv::all concept.
+* @param src The input vector.
+* @param vec_idx The index of the subvec, in units of subvec_tiles*16 elements.
+* @return A reference to the subvec.
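*
* Usage sketch (the vector and its size are illustrative): for a shared
* vector v spanning 4 subtiles (64 elements),
* @code
* auto &lo = subvec_inplace<2>(v, 0); // elements [0, 32)
* auto &hi = subvec_inplace<2>(v, 1); // elements [32, 64)
* @endcode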
+*
+* @note The subvec length must evenly divide the vector length.
+*/
+template
+__device__ inline typename SV::subvec &subvec_inplace(SV &src, int vec_idx) {
+    return *(typename SV::subvec*)(&src[vec_idx*kittens::TILE_DIM*subvec_tiles]);
+}
+
+}
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/vec/maps.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/vec/maps.cuh
new file mode 100644
index 0000000..c8b74d1
--- /dev/null
+++ b/triteia/csrc/flash_kittens/ops/warp/shared/vec/maps.cuh
@@ -0,0 +1,250 @@
+/**
+ * @file
+ * @brief Warp-scope maps on shared vectors.
+ */
+
+#pragma once
+
+#include "../../../../common/common.cuh"
+#include "../../../../types/types.cuh"
+
+
+namespace kittens {
+
+/**
+ * @brief Applies a unary operation to each element of a shared memory vector.
+ *
+ * @tparam op Unary operation type.
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector in which to store the result.
+ * @param src[in] Source vector to apply the unary operation.
+ */
+template
+__device__ static inline void unary_op(T &dst, const T &src) {
+    __syncwarp();
+    #pragma unroll
+    for(int cur = kittens::laneid(); cur < T::length; cur+=WARP_THREADS) {
+        dst[cur] = op::template op(src[cur]);
+    }
+}
+/**
+ * @brief Perform a binary operation on two shared vectors.
+ *
+ * @tparam op The binary operation to perform.
+ * @tparam T The type of the vectors.
+ * @param dst[out] The destination vector where the result is stored.
+ * @param lhs[in] The left-hand side vector for the operation.
+ * @param rhs[in] The right-hand side vector for the operation.
+ */
+template
+__device__ static inline void bin_op(T &dst, const T &lhs, const T &rhs) {
+    __syncwarp();
+    #pragma unroll
+    for(int cur = laneid(); cur < T::length; cur+=WARP_THREADS) {
+        dst[cur] = op::template op(lhs[cur], rhs[cur]);
+    }
+}
+/**
+ * @brief Perform a binary operation on a shared vector and a scalar.
+ *
+ * @tparam op The binary operation to perform.
+ * @tparam T The type of the vector.
+ * @param dst[out] The destination vector where the result is stored.
+ * @param src[in] The source vector for the operation.
+ * @param param[in] The scalar parameter for the operation.
+ */
+template
+__device__ static inline void bin_op(T &dst, const T &src, const typename T::dtype &param) {
+    __syncwarp();
+    #pragma unroll
+    for(int cur = laneid(); cur < T::length; cur+=WARP_THREADS) {
+        dst[cur] = op::template op(src[cur], param);
+    }
+}
+
+/* ---------- WRAPPERS FOR PRETTINESS ---------- */
+
+// ---- const ops ----
+
+/**
+ * @brief Sets all elements of a shared memory vector to zero.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to zero.
+ */
+template
+__device__ static inline void zero(T &dst) {
+    unary_op(dst, dst);
+}
+/**
+ * @brief Sets all elements of a shared memory vector to one.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to one.
+ */
+template
+__device__ static inline void one(T &dst) {
+    unary_op(dst, dst);
+}
+/**
+ * @brief Sets all elements of a shared memory vector to positive infinity.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to positive infinity.
+ */
+template
+__device__ static inline void pos_infty(T &dst) {
+    unary_op(dst, dst);
+}
+/**
+ * @brief Sets all elements of a shared memory vector to negative infinity.
+ *
+ * @tparam T Shared memory vector type.
+ * @param dst[out] Destination vector to be set to negative infinity. + */ +template +__device__ static inline void neg_infty(T &dst) { + unary_op(dst, dst); +} + +// ---- unary ops ---- + +/** + * @brief Copies the elements from one shared vector to another. + * + * @tparam T Shared vector type. + * @tparam U Type of the source vector. + * @param dst[out] Destination vector where the elements will be copied to. + * @param src[in] Source vector to copy the elements from. + */ +template +__device__ static inline void copy(T &dst, const U &src) { + bin_op(dst, dst, src); // the second arg is ignored here. +} +/** + * @brief Applies the exponential function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the exponential values will be stored. + * @param src[in] Source vector to apply the exponential function to. + */ +template +__device__ static inline void exp(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the natural logarithm function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the logarithm values will be stored. + * @param src[in] Source vector to apply the logarithm function to. + */ +template +__device__ static inline void log(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the absolute value function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the absolute values will be stored. + * @param src[in] Source vector to apply the absolute value function to. + */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the rectified linear unit (ReLU) function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the ReLU values will be stored. + * @param src[in] Source vector to apply the ReLU function to. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_op(dst, src); +} + +// ---- binary ops ---- + +/** + * @brief Computes the element-wise maximum of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the maximum values will be stored. + * @param lhs[in] First vector for the maximum operation. + * @param rhs[in] Second vector for the maximum operation. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise minimum of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the minimum values will be stored. + * @param lhs[in] First vector for the minimum operation. + * @param rhs[in] Second vector for the minimum operation. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise sum of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the sum values will be stored. + * @param lhs[in] First vector for the sum operation. + * @param rhs[in] Second vector for the sum operation. 
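 *
 * Usage sketch (illustrative, pre-allocated shared vectors of the same type):
 * @code
 * kittens::add(dst, lhs, rhs); // dst[i] = lhs[i] + rhs[i]
 * kittens::sub(dst, dst, rhs); // dst[i] = dst[i] - rhs[i]
 * @endcode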
+ */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise difference of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the difference values will be stored. + * @param lhs[in] First vector for the difference operation. + * @param rhs[in] Second vector for the difference operation. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise product of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the product values will be stored. + * @param lhs[in] First vector for the product operation. + * @param rhs[in] Second vector for the product operation. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise division of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the division values will be stored. + * @param lhs[in] First vector for the division operation. + * @param rhs[in] Second vector for the division operation. + */ +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} + +} diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/vec/reductions.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/vec/reductions.cuh new file mode 100644 index 0000000..6518878 --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/shared/vec/reductions.cuh @@ -0,0 +1,159 @@ +/** + * @file + * @brief Warp-scope reductions on shared vectors. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { + +/** + * @brief Performs a reduction operation on elements of a shared memory vector within a warp. + * + * This function applies a specified operation to reduce the elements of a shared memory vector `src` to a single value. + * The result is stored in `accum`. If the `reset` parameter is true, the reduction includes an initial value `src_accum`. + * The reduction operation is performed in a warp-wide context, ensuring synchronization between threads in the warp. + * + * @tparam op The operation to perform on the elements. Must provide a static `op` method. + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @tparam reset A boolean flag indicating whether to include an initial value in the reduction. + * @param[out] accum The result of the reduction operation. + * @param[in] src The shared memory vector to reduce. + * @param[in] src_accum The initial value to include in the reduction if `reset` is false. + */ +template +__device__ static inline void reduce(typename SV::dtype &dst_accum, const SV &src, const typename SV::dtype &src_accum) { + using T = SV::dtype; + int laneid = kittens::laneid(); + T accum; + if(laneid < src.length) accum = src[laneid]; // initialize a register accumulator + __syncwarp(); + for(int i = laneid+kittens::WARP_THREADS; i < src.length; i+=kittens::WARP_THREADS) { + accum = op::template op(accum, src[i]); + } + __syncwarp(); + // We can now reduce within the warp. 
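    // Each packed_shfl_down_sync below folds the upper half of the remaining
    // live lanes onto the lower half (offsets 16, 8, 4, 2, 1), so after at
    // most five steps lane 0 holds the combination of all per-lane partials;
    // the offset-16 step is compiled out for vectors short enough that only
    // the first 16 lanes ever held data.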
+ if constexpr (src.length > 16) { + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16)); + __syncwarp(); + } + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8)); + __syncwarp(); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4)); + __syncwarp(); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2)); + __syncwarp(); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1)); + __syncwarp(); + if constexpr (!reset) accum = op::template op(accum, src_accum); + // broadcast to all threads in the warp. + dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0); // everyone takes from warp leader +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +/** + * @brief Finds the maximum element in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] max_val The maximum value found in the vector. + * @param[in] src The shared memory vector to find the maximum in. + */ +template +__device__ static inline void max(typename SV::dtype &max_val, const SV &src) { + reduce(max_val, src, max_val); +} + +/** + * @brief Finds the minimum element in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] min_val The minimum value found in the vector. + * @param[in] src The shared memory vector to find the minimum in. + */ +template +__device__ static inline void min(typename SV::dtype &min_val, const SV &src) { + reduce(min_val, src, min_val); +} + +/** + * @brief Calculates the sum of elements in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] sum_val The sum of the values in the vector. + * @param[in] src The shared memory vector to sum. + */ +template +__device__ static inline void sum(typename SV::dtype &sum_val, const SV &src) { + reduce(sum_val, src, sum_val); +} + +/** + * @brief Calculates the product of elements in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] prod_val The product of the values in the vector. + * @param[in] src The shared memory vector to multiply. + */ +template +__device__ static inline void prod(typename SV::dtype &prod_val, const SV &src) { + reduce(prod_val, src, prod_val); +} + +// Three operand versions. + +/** + * @brief Finds the maximum element in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] max_val The maximum value found in the vector, accumulated with src_accum. + * @param[in] src The shared memory vector to find the maximum in. + * @param[in] src_accum The initial value to accumulate with the maximum value found. + */ +template +__device__ static inline void max(typename SV::dtype &max_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(max_val, src, src_accum); +} + +/** + * @brief Finds the minimum element in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] min_val The minimum value found in the vector, accumulated with src_accum. 
+ * @param[in] src The shared memory vector to find the minimum in. + * @param[in] src_accum The initial value to accumulate with the minimum value found. + */ +template +__device__ static inline void min(typename SV::dtype &min_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(min_val, src, src_accum); +} + +/** + * @brief Calculates the sum of elements in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] sum_val The sum of the values in the vector, accumulated with src_accum. + * @param[in] src The shared memory vector to sum. + * @param[in] src_accum The initial value to accumulate with the sum of the vector. + */ +template +__device__ static inline void sum(typename SV::dtype &sum_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(sum_val, src, src_accum); +} + +/** + * @brief Calculates the product of elements in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] prod_val The product of the values in the vector, accumulated with src_accum. + * @param[in] src The shared memory vector to multiply. + * @param[in] src_accum The initial value to accumulate with the product of the vector. + */ +template +__device__ static inline void prod(typename SV::dtype &prod_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(prod_val, src, src_accum); +} +} \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/shared/vec/vec.cuh b/triteia/csrc/flash_kittens/ops/warp/shared/vec/vec.cuh new file mode 100644 index 0000000..9d78bcd --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/shared/vec/vec.cuh @@ -0,0 +1,10 @@ +/** + * @file + * @brief An aggregate header for warp operations on data stored in shared memory. + */ + +#pragma once + +#include "conversions.cuh" +#include "maps.cuh" +#include "reductions.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/ops/warp/warp.cuh b/triteia/csrc/flash_kittens/ops/warp/warp.cuh new file mode 100644 index 0000000..4f43bab --- /dev/null +++ b/triteia/csrc/flash_kittens/ops/warp/warp.cuh @@ -0,0 +1,13 @@ +/** + * @file + * @brief An aggregate header of all warp (worker) operations defined by ThunderKittens + */ + +#pragma once + +// no namespace wrapper needed here +// as warp is the default op scope! + +#include "register/register.cuh" +#include "shared/shared.cuh" +#include "memory/memory.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/types/register/register.cuh b/triteia/csrc/flash_kittens/types/register/register.cuh new file mode 100644 index 0000000..de034ac --- /dev/null +++ b/triteia/csrc/flash_kittens/types/register/register.cuh @@ -0,0 +1,11 @@ +/** + * @file + * @brief An aggregate header file for all the register types defined by ThunderKittens. + */ + +#pragma once + +#include "rt_layout.cuh" +#include "rt_base.cuh" +#include "rv.cuh" +#include "rt.cuh" \ No newline at end of file diff --git a/triteia/csrc/flash_kittens/types/register/rt.cuh b/triteia/csrc/flash_kittens/types/register/rt.cuh new file mode 100644 index 0000000..0ff970a --- /dev/null +++ b/triteia/csrc/flash_kittens/types/register/rt.cuh @@ -0,0 +1,190 @@ +/** + * @file + * @brief The main ThunderKittens register tile struct, where most computation happens. 
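 *
 * Usage sketch (illustrative; the aliases below are defined at the end of
 * this file): a warp-owned 16x64 float accumulator is declared and cleared as
 * @code
 * kittens::rt_fl<1, 4> acc; // 1x4 subtiles of 16x16, row-major by default
 * kittens::zero(acc);
 * @endcode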
+ */ + +#pragma once + +#include +#include + +#include "../../common/common.cuh" + +#include "rt_layout.cuh" +#include "rt_base.cuh" +#include "rv.cuh" + +namespace kittens { + +/* ---------- MAIN TILE STRUCT ---------- */ + +// helper struct for type inference +namespace ducks { +/** + * @namespace rt + * + * @brief The namespace where concepts and abstract types for register tiles live. + */ +namespace rt { +/** + * @brief A dummy type used to identify register tiles. + * + * For a type to quack like an rt, it should define its identifier as ducks::rt::identifier. + * If a type quacks like ducks::rt::identifier, it will be treated as an rt by compiler checks. + */ +struct identifier {}; +} // namespace rt +} // namespace ducks + +/** + * @brief Main tile structure for manipulating data in registers. + * + * @tparam T2 The packed data type used for the matrix elements. + * @tparam _height The height of the tile in terms of the number of subtiles. + * @tparam _width The width of the tile in terms of the number of subtiles. + * @tparam _layout The layout of the internal base tiles, either row-major or column-major. + * + * This structure is designed to handle matrix tiles in a flexible manner, allowing + * for operations on tiles that are composed of smaller subtiles. It supports both + * row-major and column-major layouts and includes helper structs for type inference + * in vector maps. + * + * In general, you probably want a row-major tile, unless you specifically want to call mma + */ +template +struct rt { + using identifier = ducks::rt::identifier; ///< Type identifier for the rt structure. + using layout = _layout; ///< Layout of the matrix tile. + using dtype = T2; ///< Data type of the matrix elements. + + static constexpr int height = _height; ///< Height in subtiles. + static constexpr int width = _width; ///< Width in subtiles. + static constexpr int rows = height * rt_base::tile_size; ///< Total number of rows. + static constexpr int cols = width * rt_base::tile_size; ///< Total number of columns. + static constexpr int tile_size = rt_base::tile_size; ///< Size of the base tile. + static constexpr int num_elements = rt_base::num_elements * width * height; ///< Total number of elements. + static constexpr int elements_per_thread = rt_base::elements_per_thread * width * height; ///< Elements handled per thread. + static constexpr int packed_per_thread = rt_base::packed_per_thread * width * height; ///< Packed elements per thread. + static constexpr int packed_per_tile = rt_base::packed_per_thread; ///< Packed elements per tile. + + rt_base tiles[height][width]; ///< The actual storage for the matrix tile, organized in subtiles. + + using col_vec = rv::col_vec_pack>; ///< A type representing a column vector for this tile. + using row_vec = rv::row_vec_pack>; ///< A type representing a column vector for this tile. +}; + +/* ---------- CONCEPTS ---------- */ + +namespace ducks { +namespace rt { +/** +* @brief Concept for all register tiles. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as rt::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::rt::identifier +/** +* @brief Concept for register tiles with row layout. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a register tile. 
+* - T has an internal type layout that is ducks::rt_layout::row. +*/ +template +concept row_layout = all && std::is_same_v; +/** +* @brief Concept for register tiles with col layout. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a register tile. +* - T has an internal type layout that is ducks::rt_layout::col. +*/ +template +concept col_layout = all && std::is_same_v; + +} // namespace rt +} // namespace ducks + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// layout and type wrappers + +template using rt_fl = rt; +template using rt_bf = rt; +template using rt_hf = rt; + +// layout, type, and size wrappers +// sizes are chosen with the assumption that you aren't going to want to fit more than +// 8 subtiles on a warp. (Could be wrong!) + +/// 8 registers used +template using rt_fl_1x1 = rt_fl<1, 1, layout>; +/// 16 registers used +template using rt_fl_1x2 = rt_fl<1, 2, layout>; +/// 32 registers used +template using rt_fl_1x4 = rt_fl<1, 4, layout>; +/// 64 registers used +template using rt_fl_1x8 = rt_fl<1, 8, layout>; +/// 16 registers used +template using rt_fl_2x1 = rt_fl<2, 1, layout>; +/// 32 registers used +template using rt_fl_2x2 = rt_fl<2, 2, layout>; +/// 64 registers used +template using rt_fl_2x4 = rt_fl<2, 4, layout>; +/// 32 registers used +template using rt_fl_4x1 = rt_fl<4, 1, layout>; +/// 64 registers used +template using rt_fl_4x2 = rt_fl<4, 2, layout>; +/// 64 registers used +template using rt_fl_8x1 = rt_fl<8, 1, layout>; + +/// 4 registers used +template using rt_bf_1x1 = rt_bf<1, 1, layout>; +/// 8 registers used +template using rt_bf_1x2 = rt_bf<1, 2, layout>; +/// 16 registers used +template using rt_bf_1x4 = rt_bf<1, 4, layout>; +/// 32 registers used +template using rt_bf_1x8 = rt_bf<1, 8, layout>; +/// 8 registers used +template using rt_bf_2x1 = rt_bf<2, 1, layout>; +/// 16 registers used +template using rt_bf_2x2 = rt_bf<2, 2, layout>; +/// 32 registers used +template using rt_bf_2x4 = rt_bf<2, 4, layout>; +/// 16 registers used +template using rt_bf_4x1 = rt_bf<4, 1, layout>; +/// 32 registers used +template using rt_bf_4x2 = rt_bf<4, 2, layout>; +/// 32 registers used +template using rt_bf_8x1 = rt_bf<8, 1, layout>; + +/// 4 registers used +template using rt_hf_1x1 = rt_hf<1, 1, layout>; +/// 8 registers used +template using rt_hf_1x2 = rt_hf<1, 2, layout>; +/// 16 registers used +template using rt_hf_1x4 = rt_hf<1, 4, layout>; +/// 32 registers used +template using rt_hf_1x8 = rt_hf<1, 8, layout>; +/// 8 registers used +template using rt_hf_2x1 = rt_hf<2, 1, layout>; +/// 16 registers used +template using rt_hf_2x2 = rt_hf<2, 2, layout>; +/// 32 registers used +template using rt_hf_2x4 = rt_hf<2, 4, layout>; +/// 16 registers used +template using rt_hf_4x1 = rt_hf<4, 1, layout>; +/// 32 registers used +template using rt_hf_4x2 = rt_hf<4, 2, layout>; +/// 32 registers used +template using rt_hf_8x1 = rt_hf<8, 1, layout>; + +} // namespace kittens diff --git a/triteia/csrc/flash_kittens/types/register/rt_base.cuh b/triteia/csrc/flash_kittens/types/register/rt_base.cuh new file mode 100644 index 0000000..895b76d --- /dev/null +++ b/triteia/csrc/flash_kittens/types/register/rt_base.cuh @@ -0,0 +1,92 @@ +/** + * @file + * @brief The basic 16x16 register tile on which larger register tiles are built. 
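 *
 * For scale: a 16x16 base tile holds 256 elements, so each of a warp's 32
 * lanes owns 8 of them (4 packed pairs for 2-wide types such as float2).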
+ */ + +#pragma once + +#include + +#include "../../common/common.cuh" +#include "rt_layout.cuh" + +namespace kittens { + +/* ---------- BASE 16x16 SUBTILE STRUCT ---------- */ + +namespace ducks { +/** + * @namespace rt_base + * + * @brief The namespace where concepts and abstract types for register base (16x16) tiles live. + */ +namespace rt_base { +/** + * @brief A dummy type used to identify register base tiles. + * + * For a type to quack like an rt_base, it should define its identifier as ducks::rt_base::identifier. + * If a type quacks like ducks::rt_base::identifier, it will be treated as an rt_base by compiler checks. + */ +struct identifier {}; +} +} // namespace ducks + +/** + * @brief Basic tile structure for computation in registers. + * + * @tparam T2 The packed data type used for the matrix elements. + * @tparam _layout The layout of the base tile, either row-major or column-major. + * + * This type is a primarily utility for building larger inline templates + * out of PTX primitives and managing layouts. + * + * In general, you probably want a row-major tile, unless you specifically want to call mma + */ +template struct rt_base { + using identifier = ducks::rt_base::identifier; ///< Type identifier for the rt_base structure. + using layout = _layout; ///< Layout of the matrix tile. + using dtype = T2; ///< Data type of the matrix elements + + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v, + "rt_base was provided an unsupported type." + ); + + static constexpr int tile_size = 16; ///< Tile size is a constant 16. + static constexpr int rows = tile_size; ///< Number of rows. + static constexpr int cols = tile_size; ///< Number of cols. + static constexpr int num_elements = rows*cols; // 256 + static constexpr int elements_per_thread = num_elements / 32; // 8 + + static constexpr int packed_per_thread = elements_per_thread / base_types::packing::num(); // 4 + static constexpr int registers_per_thread = packed_per_thread * sizeof(T2) / 4; // 4 or 8, registers are 32-bit words + + static constexpr int col_vec_pack = layout::is_row ? 1 : 2; // for holding row reductions + static constexpr int row_vec_pack = layout::is_row ? 2 : 1; // for holding column reductions + + T2 data[packed_per_thread]; ///< The actual storage for the base tile +}; + +/* ---------- CONCEPTS ---------- */ + +namespace ducks { +namespace rt_base { +/** +* @brief Concept for all register base tiles. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as rt_base::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::rt::identifier +} // namespace rt +} // namespace ducks + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +template using rt_base_fl = rt_base; // Note float2! Otherwise you will get bugs. +template using rt_base_bf = rt_base; + +} diff --git a/triteia/csrc/flash_kittens/types/register/rt_layout.cuh b/triteia/csrc/flash_kittens/types/register/rt_layout.cuh new file mode 100644 index 0000000..273cb66 --- /dev/null +++ b/triteia/csrc/flash_kittens/types/register/rt_layout.cuh @@ -0,0 +1,42 @@ +/** + * @file + * @brief Layouts and their manipulations for register tiles. + */ + +#pragma once + +#include + +namespace kittens { +namespace ducks { +/** + * @namespace rt_layout + * + * @brief A namespace for template metaprogramming with register tile layouts. 
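 *
 * For example, the transpose metafunction defined below flips a layout at
 * compile time:
 * @code
 * static_assert(std::is_same_v<transpose<row>::type, col>);
 * static_assert(std::is_same_v<transpose<col>::type, row>);
 * @endcode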
+ */
+namespace rt_layout {
+
+/**
+ * @brief A dummy type used to identify a row-major layout for a register tile.
+ */
+struct row { static constexpr bool is_row=true; }; // for most matrices
+/**
+ * @brief A dummy type used to identify a col-major layout for a register tile.
+ */
+struct col { static constexpr bool is_row=false; }; // for the B-matrix of MMA ops.
+
+/**
+ * @brief A concept to check if a type is a register tile layout.
+ */
+template<typename T>
+concept all = std::is_same_v<T, row> || std::is_same_v<T, col>;
+
+/**
+ * @brief A struct to generate a transposed layout.
+ */
+template<all L> struct transpose      { using type = col; };
+template<>      struct transpose<col> { using type = row; };
+
+} // namespace rt_layout
+} // namespace ducks
+} // namespace kittens
diff --git a/triteia/csrc/flash_kittens/types/register/rv.cuh b/triteia/csrc/flash_kittens/types/register/rv.cuh
new file mode 100644
index 0000000..6c74a73
--- /dev/null
+++ b/triteia/csrc/flash_kittens/types/register/rv.cuh
@@ -0,0 +1,83 @@
+/**
+ * @file
+ * @brief Register vectors for computations on axes.
+ */
+
+#pragma once
+
+#include <concepts>
+#include <type_traits>
+
+#include "../../common/common.cuh"
+#include "rt_layout.cuh"
+
+namespace kittens {
+
+/* ---------- MAIN VECTOR STRUCT ---------- */
+
+// helper struct for type inference
+namespace ducks {
+/**
+ * @namespace rv
+ *
+ * @brief The namespace where concepts and abstract types for register vectors live.
+ */
+namespace rv {
+/**
+ * @brief A dummy type used to identify register vectors.
+ *
+ * For a type to quack like an rv, it should define its identifier as ducks::rv::identifier.
+ * If a type quacks like ducks::rv::identifier, it will be treated as an rv by compiler checks.
+ */
+struct identifier {};
+}
+}
+/**
+ * @brief Register vector structure.
+ *
+ * @tparam _T The packed data type used for the vector elements.
+ * @tparam _outer_dim The size of the tile, in units of TILE_DIM (16).
+ * @tparam _inner_dim This controls the layout of the tile in terms of which axis it maps on the register tile layout.
+ *
+ * Register vectors are used to accumulate and map values across tiles. You can do computation
+ * on them directly if you want, but they're not designed to be maximally efficient vectors
+ * as they have substantial duplication and strange layouts to help them work efficiently with
+ * the register layouts used by the tensor cores. ThunderKittens wants you working with tiles
+ * where possible!
+ */
+template<typename _T, size_t _outer_dim, size_t _inner_dim=1>
+struct rv {
+    using identifier = ducks::rv::identifier; ///< Type identifier for the rv structure.
+    using dtype = _T; ///< Data type of the vector elements.
+
+    static constexpr int outer_dim = _outer_dim; ///< Length in subtiles.
+    static constexpr int inner_dim = _inner_dim; ///< Internal layout within a subtile. Either 1 or 2.
+
+    dtype data[outer_dim][inner_dim]; ///< The actual register vector data.
+
+    __device__ inline       dtype* operator[](size_t idx)       { return &data[idx][0]; } ///< A wrapper for indexing into vector data.
+    __device__ inline const dtype* operator[](size_t idx) const { return &data[idx][0]; } ///< A wrapper for indexing into vector data.
+    __device__ inline       dtype& operator[](int2 outin)       { return data[outin.x][outin.y]; } ///< A wrapper for indexing into vector data.
+    __device__ inline const dtype& operator[](int2 outin) const { return data[outin.x][outin.y]; } ///< A wrapper for indexing into vector data.
+};
+
+/* ---------- CONCEPTS ---------- */
+
+namespace ducks {
+namespace rv {
+/**
+* @brief Concept for all register vectors.
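+* For instance, kittens::rv<float2, 4>, a register vector spanning four 16-element subtiles,
+* quacks like an rv and satisfies this concept.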
+* @tparam T The type to check against the concept requirements.
+*
+* Requires:
+* - T has a nested type identifier that is the same as rv::identifier.
+*/
+template<typename T>
+concept all = requires {
+    typename T::identifier; // Checks if T::identifier exists
+} && std::is_same_v<typename T::identifier, identifier>; // Checks if T::identifier is ducks::rv::identifier.
+
+} // namespace rv
+} // namespace ducks
+
+} // namespace kittens
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/types/shared/shared.cuh b/triteia/csrc/flash_kittens/types/shared/shared.cuh
new file mode 100644
index 0000000..4ee6df2
--- /dev/null
+++ b/triteia/csrc/flash_kittens/types/shared/shared.cuh
@@ -0,0 +1,10 @@
+/**
+ * @file
+ * @brief An aggregate header file for all the shared types defined by ThunderKittens.
+ */
+
+#pragma once
+
+#include "st_layout.cuh"
+#include "sv.cuh"
+#include "st.cuh"
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/types/shared/st.cuh b/triteia/csrc/flash_kittens/types/shared/st.cuh
new file mode 100644
index 0000000..ab62dba
--- /dev/null
+++ b/triteia/csrc/flash_kittens/types/shared/st.cuh
@@ -0,0 +1,212 @@
+/**
+ * @file
+ * @brief The ThunderKittens shared tile struct.
+ */
+
+#pragma once
+
+#include "../../common/common.cuh"
+#include "st_layout.cuh"
+#include "sv.cuh"
+
+/* ---------- MAIN TILE STRUCT ---------- */
+
+// these are helper structs for type inference
+namespace kittens {
+namespace ducks {
+/**
+ * @namespace st
+ *
+ * @brief The namespace where concepts and abstract types for shared tiles live.
+ */
+namespace st {
+/**
+ * @brief A dummy type used to identify shared tiles.
+ *
+ * For a type to quack like an st, it should define its identifier as ducks::st::identifier.
+ * If a type quacks like ducks::st::identifier, it will be treated as an st by compiler checks.
+ * This is particularly useful for subtiles on challenging layouts.
+ */
+struct identifier {};
+} // namespace st
+} // namespace ducks
+
+// Forward declaration of subtile
+template<
+    typename _T,
+    int _underlying_height,
+    int _underlying_width,
+    ducks::st_layout::all _layout,
+    int _subtile_height,
+    int _subtile_width
+>
+struct st_subtile;
+
+/**
+ * @brief Shared memory tile structure for various data types and layouts.
+ *
+ * @tparam T The data type of the elements in the tile. Not packed!
+ * @tparam _height The height of the tile in units of 16-element subtiles.
+ * @tparam _width The width of the tile in units of 16-element subtiles.
+ * @tparam _layout The memory layout of the tile.
+ */
+template<typename _T, int _height, int _width, ducks::st_layout::all _layout>
+struct KITTENS_DEFAULT_ALIGN st {
+    using identifier = ducks::st::identifier; ///< Type identifier for shared memory tile.
+    using layout = _layout; ///< Memory layout of the tile.
+    using dtype = _T; ///< Data type of the elements in the tile.
+
+    // define underlying data as same as that projected, to make clear that this is *not* a subtile.
+    static constexpr int underlying_height = _height;
+    static constexpr int underlying_width = _width;
+    static constexpr int underlying_rows = underlying_height * kittens::TILE_DIM;
+    static constexpr int underlying_cols = underlying_width * kittens::TILE_DIM;
+    static constexpr int underlying_num_elements = underlying_rows * underlying_cols;
+
+    static constexpr int height = _height; ///< Height of the tile in terms of 16-element subtiles.
+    static constexpr int width = _width; ///< Width of the tile in terms of 16-element subtiles.
+    static constexpr int rows = height * kittens::TILE_DIM; ///< Total number of rows in the tile.
+    static constexpr int cols = width * kittens::TILE_DIM; ///< Total number of cols in the tile.
+    static constexpr int num_elements = rows * cols; ///< Total number of elements in the tile.
+
+    static_assert(base_types::packing<dtype>::num() == 1); // must be a 1-packed type (e.g. float, bf16, etc)
+
+    static constexpr int swizzle_bytes = (
+        std::is_same_v<layout, ducks::st_layout::swizzle> || std::is_same_v<layout, ducks::st_layout::wgmma_swizzle> ? (
+            underlying_width%4 == 0 ? 128 :
+            underlying_width%2 == 0 ? 64 :
+            32
+        ) : 0
+    );
+
+    // wgmma layout with swizzling
+    dtype data[rows*cols]; ///< Raw data storage for the tile.
+
+    /**
+     * @brief Access a shared tile element using a row and column, as if the tile were row-major.
+     *
+     * This is the preferred way to access memory within a shared tile, which abstracts
+     * indexing calculations for swizzled or strangely ordered layouts.
+     */
+    __device__ inline dtype& operator[](const int2 &rowcol) {
+        return *detail::shared_indexer<height, width, layout>::idx(data, rowcol.x, rowcol.y);
+    }
+    __device__ inline const dtype& operator[](const int2 &rowcol) const {
+        return *(const bf16*)detail::shared_indexer<height, width, layout>::idx((bf16*)data, rowcol.x, rowcol.y);
+    }
+    __device__ inline dtype& operator[](int idx) {
+        return data[idx];
+    }
+    __device__ inline const dtype& operator[](int idx) const {
+        return data[idx];
+    }
+
+    // vector types
+    using col_vec = sv<dtype, height>; ///< Column vector type for this tile
+    using row_vec = sv<dtype, width>; ///< Row vector type for this tile
+    template<int subtile_height, int subtile_width> using subtile = st_subtile<
+        dtype, height, width, layout,
+        subtile_height, subtile_width
+    >; ///< A templated subtile type wrapper for this tile.
+};
+
+
+/**
+ * @brief A reference into a chunk of shared tile memory.
+ *
+ * The st_subtile is a drop-in replacement for an st which internally
+ * references the appropriate memory while performing minimal address
+ * calculations. You should never create one directly; instead, have
+ * subtile_inplace return it for you. (`auto` is nice.)
+ *
+ * You can generally just pretend this is an st. But not for wgmma's.
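+ *
+ * For example (a sketch; normally subtile_inplace constructs this for you):
+ * @code
+ * kittens::st_bf_4x4<>::subtile<2, 2> sub(&tile[0], 32, 32); // given st_bf_4x4<> &tile
+ * bf16 v = sub[{0, 0}]; // reads the same element as tile[{32, 32}]
+ * @endcode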
+ */
+template<
+    typename _T,
+    int _underlying_height,
+    int _underlying_width,
+    ducks::st_layout::all _layout,
+    int _subtile_height,
+    int _subtile_width
+>
+struct st_subtile {
+    using identifier = ducks::st::identifier; // i quack like an st, gcc will never know the difference
+    using layout = _layout;
+    using dtype = _T;
+
+    static constexpr int underlying_height = _underlying_height;
+    static constexpr int underlying_width = _underlying_width;
+    static constexpr int underlying_rows = underlying_height * kittens::TILE_DIM;
+    static constexpr int underlying_cols = underlying_width * kittens::TILE_DIM;
+    static constexpr int underlying_num_elements = underlying_rows * underlying_cols;
+
+    static constexpr int height = _subtile_height;
+    static constexpr int width = _subtile_width;
+    static constexpr int rows = height * kittens::TILE_DIM;
+    static constexpr int cols = width * kittens::TILE_DIM;
+    static constexpr int num_elements = rows * cols;
+
+    dtype *data;
+    int row_offset, col_offset;
+
+    __device__ st_subtile(dtype *src, int _row_offset, int _col_offset) {
+        data = src;
+        row_offset = _row_offset;
+        col_offset = _col_offset;
+    }
+
+    __device__ inline dtype& operator[](const int2 &rowcol) {
+        return *detail::shared_indexer<underlying_height, underlying_width, layout>::idx(
+            (bf16*)data, rowcol.x+row_offset, rowcol.y+col_offset
+        );
+    }
+    __device__ inline const dtype& operator[](const int2 &rowcol) const {
+        return *(const bf16*)detail::shared_indexer<underlying_height, underlying_width, layout>::idx(
+            (bf16*)data, rowcol.x+row_offset, rowcol.y+col_offset
+        );
+    }
+
+    // single-index operator[] is left undefined as it would likely be an improper use of the st_subtile type
+
+    // vector types
+    using col_vec = sv<dtype, height>;
+    using row_vec = sv<dtype, width>;
+};
+
+/* ---------- CONCEPTS ---------- */
+
+namespace ducks {
+namespace st {
+
+/**
+* @brief Concept for all shared tiles.
+* @tparam T The type to check against the concept requirements.
+*
+* Requires:
+* - T has a nested type identifier that is the same as st::identifier.
+*/
+template<typename T> concept all = requires {
+    typename T::identifier; // Checks if T::identifier exists
+} && std::is_same_v<typename T::identifier, identifier>; // Checks if T::identifier is ducks::st::identifier
+
+} // namespace st
+} // namespace ducks
+
+
+/* ---------- WRAPPERS FOR PRETTINESS ---------- */
+
+template<int _height, int _width, ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf = st<bf16, _height, _width, layout>; // prelim tests indicate this is fastest default
+
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_1x1 = st_bf<1, 1, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_1x2 = st_bf<1, 2, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_1x4 = st_bf<1, 4, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_1x8 = st_bf<1, 8, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_2x1 = st_bf<2, 1, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_2x2 = st_bf<2, 2, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_2x4 = st_bf<2, 4, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_4x1 = st_bf<4, 1, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_4x2 = st_bf<4, 2, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_4x4 = st_bf<4, 4, layout>;
+template<ducks::st_layout::all layout=ducks::st_layout::swizzle> using st_bf_8x1 = st_bf<8, 1, layout>;
+
+}
diff --git a/triteia/csrc/flash_kittens/types/shared/st_layout.cuh b/triteia/csrc/flash_kittens/types/shared/st_layout.cuh
new file mode 100644
index 0000000..27699c6
--- /dev/null
+++ b/triteia/csrc/flash_kittens/types/shared/st_layout.cuh
@@ -0,0 +1,137 @@
+/**
+ * @file
+ * @brief A collection of layouts and indexing patterns for shared memory tiles.
+ */
+
+#pragma once
+
+#include <type_traits>
+
+namespace kittens {
+namespace ducks {
+/**
+ * @namespace st_layout
+ *
+ * @brief A namespace for template metaprogramming with shared tile layouts.
+ */
+namespace st_layout {
+
+/**
+ * @brief A naive row-major layout with no swizzling: row*(#cols)+c
+ */
+struct naive {}; // swizzling_mode left undefined to cause errors if matrix_descriptor is called.
+/**
+ * @brief A layout for minimal bank conflicts and maximal coalescing.
+ */
+struct swizzle {}; // only defined for x1, x2, x4 tiles.
+
+/**
+ * @brief A layout specialized to match both TMA and WGMMA.
+ *
+ * Note this layout has worse coalescing than the standard swizzle mode
+ * for tiles whose width isn't a multiple of 4, unless the width is exactly 1 or 2.
+ */
+struct wgmma_swizzle {}; // only defined for x1, x2, x4 tiles.
+/**
+ * @brief A layout for wgmma with no swizzling. This mode is necessary for the wgmma transpose.
+ *
+ * Note, it has worse coalescing and bank conflicts than any other mode.
+ */
+struct wgmma_interleave { static constexpr int swizzling_mode=0; };
+
+/**
+ * @brief Concept to check if a type is a row-contiguous layout.
+ */
+template<typename T>
+concept all = (
+    std::is_same_v<T, naive>           ||
+    std::is_same_v<T, swizzle>         ||
+    std::is_same_v<T, wgmma_swizzle>   ||
+    std::is_same_v<T, wgmma_interleave>
+);
+
+} // namespace st_layout
+} // namespace ducks
+
+/**
+ * @namespace detail
+ *
+ * @brief A namespace for internal calculations that really don't need to be exposed.
+ */
+namespace detail {
+
+/**
+ * @brief Struct template to calculate addresses in shared memory tiles
+ *
+ * @tparam height The tile height, in subtiles of 16.
+ * @tparam width The tile width, in subtiles of 16.
+ * @tparam T The layout type.
+ * @param r[in] The row position.
+ * @param c[in] The column position.
+ * @return The calculated index.
+ */
+template<int height, int width, ducks::st_layout::all T=ducks::st_layout::naive>
+struct shared_indexer {
+    static constexpr int rows = height*16;
+    static constexpr int cols = width*16;
+    /**
+     * @brief Get a memory offset from a row and column index.
+     */
+    __device__ static inline bf16* idx(bf16 *ptr, int r, int c) { // naive row-major indexing
+        return &ptr[r*cols + c];
+    }
+};
+template<int height, int width>
+struct shared_indexer<height, width, ducks::st_layout::swizzle> {
+    static constexpr int rows = height*16;
+    static constexpr int cols = width*16;
+    static constexpr int swizzle_repeat = (width%4==0) ? 1024 : (width%2==0) ? 512 : 256;
+    static constexpr int swizzle_shift  = (width%4==0) ?    6 : (width%2==0) ?   5 :   4;
+    __device__ static inline bf16* idx(bf16 *ptr, int r, int c) { // swizzled row-major indexing
+        const uint64_t addr = (uint64_t)(&ptr[r*cols + c]);
+        const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
+        return (bf16*)(addr ^ swizzle);
+    }
+};
+template<int height, int width>
+struct shared_indexer<height, width, ducks::st_layout::wgmma_swizzle> {
+    static constexpr int rows = height*16;
+    static constexpr int cols = width*16;
+    static constexpr int swizzle_repeat = (width%4==0) ? 1024 : (width%2==0) ? 512 : 256;
+    static constexpr int swizzle_shift  = (width%4==0) ?    6 : (width%2==0) ?   5 :   4;
+    static constexpr int subtile_cols   = (width%4==0) ?   64 : (width%2==0) ?  32 :  16;
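+    // Worked example of the swizzle arithmetic (same scheme as the swizzle layout above):
+    // for width%4==0 the repeat is 1024 bytes; the element 128 bytes into it has
+    // addr%1024 == 128, so swizzle = (128>>7)<<4 = 16, i.e. each successive 128-byte row
+    // XORs its address with the next multiple of 16 bytes.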
+    __device__ static inline bf16* idx(bf16 *ptr, int r, int c) { // swizzled wgmma-compatible indexing
+        const int outer_idx = c/subtile_cols;
+        const uint64_t addr = (uint64_t)(&ptr[outer_idx*rows*subtile_cols + r*subtile_cols + c%subtile_cols]);
+        const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
+        return (bf16*)(addr ^ swizzle);
+    }
+};
+template<int height, int width>
+struct shared_indexer<height, width, ducks::st_layout::wgmma_interleave> {
+    static constexpr int rows = height*16;
+    static constexpr int cols = width*16;
+    static constexpr int rows_per_core_matrix = 8;
+    static constexpr int cols_per_core_matrix = 8;
+    __device__ static inline bf16* idx(bf16 *ptr, int r, int c) { // interleaved core-matrix indexing
+        int idx1 = r/rows_per_core_matrix;
+        int idx2 = c/cols_per_core_matrix;
+        int idx3 = (r%rows_per_core_matrix);
+        int idx4 = (c%cols_per_core_matrix);
+        return &ptr[(
+            (
+                (
+                    idx1 * (2*width) // width is in units of 16, but we want units of 8
+                    + idx2
+                ) * 8 // * 8 rows per tensormap
+                + idx3
+            ) * 8 // * 8 columns per row
+            + idx4
+        )];
+    }
+};
+
+} // namespace detail
+} // namespace kittens
diff --git a/triteia/csrc/flash_kittens/types/shared/sv.cuh b/triteia/csrc/flash_kittens/types/shared/sv.cuh
new file mode 100644
index 0000000..228762e
--- /dev/null
+++ b/triteia/csrc/flash_kittens/types/shared/sv.cuh
@@ -0,0 +1,95 @@
+/**
+ * @file
+ * @brief The ThunderKittens shared vector struct.
+ */
+
+#pragma once
+
+#include <concepts>
+#include <type_traits>
+
+#include "../../common/common.cuh"
+
+namespace kittens {
+
+/* ---------- MAIN VECTOR STRUCT ---------- */
+
+namespace ducks {
+/**
+ * @namespace sv
+ *
+ * @brief The namespace where concepts and abstract types for shared vectors live.
+ */
+namespace sv {
+/**
+ * @brief A dummy type used to identify shared vectors.
+ *
+ * For a type to quack like an sv, it should define its identifier as ducks::sv::identifier.
+ * If a type quacks like ducks::sv::identifier, it will be treated as an sv by compiler checks.
+ */
+struct identifier {};
+}
+}
+
+/**
+ * @brief Shared vector structure.
+ *
+ * @tparam _T The packed data type used for the vector elements.
+ * @tparam _tiles The size of the tile, in units of TILE_DIM (16).
+ *
+ * Shared vectors are used to accumulate and map values across shared tiles.
+ * Unlike every other structure present in ThunderKittens, these have a simple
+ * uniform layout which is just an array in memory. EZ!
+ */
+template<typename _T, size_t _tiles>
+struct KITTENS_DEFAULT_ALIGN sv {
+    using identifier = ducks::sv::identifier;
+    using dtype = _T;
+
+    static constexpr int tiles = _tiles; ///< Length in subtiles.
+    static constexpr int length = tiles * kittens::TILE_DIM; ///< Length in elements.
+
+    dtype data[length]; ///< The actual shared vector data.
+
+    __device__ inline       dtype& operator[](size_t idx)       { return data[idx]; }
+    __device__ inline const dtype& operator[](size_t idx) const { return data[idx]; }
+
+    template<size_t sub_tiles> using subvec = sv<dtype, sub_tiles>; ///< A subvector which allows warpgroups and blocks to work cooperatively.
+};
+
+/* ---------- CONCEPTS ---------- */
+
+namespace ducks {
+namespace sv {
+/**
+* @brief Concept for all shared vectors.
+* @tparam T The type to check against the concept requirements.
+*
+* Requires:
+* - T has a nested type identifier that is the same as sv::identifier.
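+* For instance, kittens::sv<bf16, 4> (aliased below as sv_bf_4) satisfies this concept.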
+*/
+template<typename T>
+concept all = requires {
+    typename T::identifier; // Checks if T::identifier exists
+} && std::is_same_v<typename T::identifier, identifier>; // Checks if T::identifier is ducks::sv::identifier
+
+} // namespace sv
+} // namespace ducks
+
+
+/* ---------- WRAPPERS FOR PRETTINESS ---------- */
+
+// vector types
+template<size_t _tiles> using sv_bf = sv<bf16, _tiles>;
+template<size_t _tiles> using sv_fl = sv<float, _tiles>;
+
+using sv_bf_1 = sv<bf16, 1>;
+using sv_bf_2 = sv<bf16, 2>;
+using sv_bf_4 = sv<bf16, 4>;
+using sv_bf_8 = sv<bf16, 8>;
+using sv_fl_1 = sv<float, 1>;
+using sv_fl_2 = sv<float, 2>;
+using sv_fl_4 = sv<float, 4>;
+using sv_fl_8 = sv<float, 8>;
+
+} // namespace kittens
\ No newline at end of file
diff --git a/triteia/csrc/flash_kittens/types/types.cuh b/triteia/csrc/flash_kittens/types/types.cuh
new file mode 100644
index 0000000..27d72b1
--- /dev/null
+++ b/triteia/csrc/flash_kittens/types/types.cuh
@@ -0,0 +1,51 @@
+/**
+ * @file
+ * @brief An aggregate header file for all the register and shared types defined by ThunderKittens.
+ */
+
+#pragma once
+
+#include "register/register.cuh"
+#include "shared/shared.cuh"
+
+/* ---------- WRAPPERS FOR PRETTINESS ---------- */
+
+namespace kittens {
+
+/**
+ * @brief Row vector type alias.
+ *
+ * This template alias provides a convenient way to refer to the row vector type
+ * associated with a given class or type `T`. It assumes that the class `T` has
+ * a nested type named `row_vec`.
+ *
+ * @tparam T The class or type for which the row vector type is defined.
+ *
+ * Example usage:
+ * @code
+ * kittens::row_vec<decltype(some_tile)> row_vector;
+ * @endcode
+ */
+template<typename T>
+using row_vec = T::row_vec;
+
+/**
+ * @brief Column vector type alias.
+ *
+ * This template alias provides a convenient way to refer to the column vector type
+ * associated with a given class or type `T`. It assumes that the class `T` has
+ * a nested type named `col_vec`.
+ *
+ * @tparam T The class or type for which the column vector type is defined.
+ *
+ * Example usage:
+ * @code
+ * kittens::col_vec<decltype(some_tile)> col_vector;
+ * @endcode
+ */
+template<typename T>
+using col_vec = T::col_vec;
+
+// ^ this code lives here because it applies to both sv and rv types
+
+} // namespace kittens
diff --git a/triteia/python/configs/gpus/specs.py b/triteia/python/configs/gpus/specs.py
index 9bad193..694a0e4 100644
--- a/triteia/python/configs/gpus/specs.py
+++ b/triteia/python/configs/gpus/specs.py
@@ -28,7 +28,7 @@
     "fp32_tflops": 494.7,
 }
 
-nvidia_gpus = [nvidia_rtx_3090,nvidia_rtx_a6000, nvidia_gh200_120gb]
+nvidia_gpus = [nvidia_rtx_3090, nvidia_rtx_a6000, nvidia_gh200_120gb]
 
 
 def get_gpu_device_info():
diff --git a/triteia/python/nn/linear.py b/triteia/python/nn/linear.py
index 9f338fc..7cb8f12 100644
--- a/triteia/python/nn/linear.py
+++ b/triteia/python/nn/linear.py
@@ -105,6 +105,7 @@ def pack(self, weight, scales, trans=False):
             s = s.reshape((1, -1))
 
         mask = mask_creator(w.T, n=2, m=4).cuda().bool()
+        # mask = torch.ones_like(w.t()).cuda().bool()
        w = torch.round(w / s).int()
         w += (maxq + 1) // 2
         w = torch.clamp(w, 0, maxq)
@@ -116,6 +117,8 @@ def pack(self, weight, scales, trans=False):
         else:
             s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
         w = mask * w.T
+
+        w = w.contiguous()
         w, meta = sparse_semi_structured_from_dense_cutlass(w)
         w = w.t()
         self.infeatures = self.infeatures // 2
diff --git a/triteia/python/ops/utils/generator.py b/triteia/python/ops/utils/generator.py
index 628484e..19557ce 100644
--- a/triteia/python/ops/utils/generator.py
+++ b/triteia/python/ops/utils/generator.py
@@ -83,3 +83,77 @@ def generate_model_distribution(distribution, num_queries, num_models):
     probs = np.array(probs) / sum(probs)
     models = np.random.choice(to_eval_models, num_queries, p=probs)
     return torch.tensor(models, dtype=torch.int32, device="cuda")
+
+
+def fp16_to_sparse(weights, scale, device="cuda"):
+    from triteia.python.nn.linear import sparse_low_precision_linear
+
+    k, m = weights.shape
+    k_sp = k // 2
+    s = scale
+    layer = sparse_low_precision_linear(k, m, groupsize=-1)
+    groupsize = k
+    layer.groupsize = groupsize
+    layer.B = torch.empty((k_sp // 16, m * 16 // 8), dtype=torch.int, device=device)
+    layer.meta = torch.empty((m, k // 16), dtype=torch.int16, device=device)
+    layer.s = torch.empty(
+        (k_sp // (groupsize // 2), m), dtype=torch.half, device=device
+    )
+    layer.pack(weights, scale, True)
+    q = layer.B
+    s = layer.s
+    meta = layer.meta
+    return weights, q, s, meta, layer.meta_raw
+
+
+@torch.no_grad()
+def torch_weight_to_sparse_marlin(weight, scale, tp_size=1, chunk_by="column"):
+    """
+    Args:
+        weight: torch.Tensor of shape (in_features, out_features)
+        scale: torch.Tensor of shape (1, out_features)
+        tp_size: tensor parallelism size
+        chunk_by: "column" or "row"
+    """
+    from triteia.python.nn.linear import sparse_low_precision_linear
+
+    assert chunk_by in ["column", "row"], "chunk_by must be either 'column' or 'row'"
+    assert weight.dim() == 2, "weight must be a 2D tensor"
+    assert weight.size(0) % tp_size == 0, "in_features must be divisible by tp_size"
+    assert weight.size(1) == scale.size(
+        1
+    ), "out_features of weight and scale must match"
+
+    if not weight.is_contiguous():
+        weight = weight.contiguous()
+    if not scale.is_contiguous():
+        scale = scale.contiguous()
+
+    qweights, scales, metas = [], [], []
+    for i in range(tp_size):
+        if chunk_by == "column":
+            tp_weight = weight[
+                :, i * weight.size(1) // tp_size : (i + 1) * weight.size(1) // tp_size
+            ]
+            tp_scales = scale[
+                :, i * scale.size(1) // tp_size : (i + 1) * scale.size(1) // tp_size
+            ]
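+        # column chunking keeps the full in_features on every rank and splits
+        # out_features across ranks; row chunking (below) splits in_features
+        # and reuses the full scale tensor on every rank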
elif chunk_by == "row": + tp_weight = weight[ + i * weight.size(0) // tp_size : (i + 1) * weight.size(0) // tp_size, : + ] + tp_scales = scale + layer = sparse_low_precision_linear( + infeatures=tp_weight.size(0), outfeatures=tp_weight.size(1), groupsize=-1 + ) + k, m = tp_weight.size(0), tp_weight.size(1) + k_sp = k // 2 + layer.groupsize = k + layer.B = torch.empty((k_sp // 16, m * 16 // 8), dtype=torch.int) + layer.meta = torch.empty((m, k // 16), dtype=torch.int16) + layer.s = torch.empty((k_sp // (k // 2), m), dtype=torch.half) + layer.pack(tp_weight, scales=tp_scales, trans=True) + qweights.append(layer.B.cuda().contiguous()) + scales.append(layer.s.cuda().contiguous()) + metas.append(layer.meta.cuda().contiguous()) + return qweights, scales, metas diff --git a/triteia/python/ops/utils/sparsity.py b/triteia/python/ops/utils/sparsity.py index 113c082..6c07398 100644 --- a/triteia/python/ops/utils/sparsity.py +++ b/triteia/python/ops/utils/sparsity.py @@ -281,14 +281,12 @@ def sparse_semi_structured_from_dense_cutlass(dense): | (meta_n[:, :, 6] << 24) | (meta_n[:, :, 7] << 28) ) - # Reorder meta tensor elements. meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined] meta_offsets = _calculate_meta_reordering_scatter_offsets( m, meta_ncols, meta_dtype, device ) meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - return (sparse, meta_reordered.view(m, meta_ncols)) @@ -396,5 +394,20 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): return dense.view(m, 2 * k) +def reorder_meta(meta): + m, k = meta.size(0), meta.size(1) * 16 + ksparse = 4 + meta_dtype = torch.int16 + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + # 256 // (2 * 4) = 32 + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + meta_reordered = meta.new_empty((m * meta_ncols,)) + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, "cuda" + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + return meta_reordered.view(m, meta_ncols) + + _perm, _scale_perm, _scale_perm_single = _get_perms() _perm_2_4, _scale_perm_2_4, _scale_perm_single_2_4 = _get_perms_2_4() diff --git a/triteia/python/utils/benchmark.py b/triteia/python/utils/benchmark.py index 5a0a6ee..f7c23db 100644 --- a/triteia/python/utils/benchmark.py +++ b/triteia/python/utils/benchmark.py @@ -6,6 +6,7 @@ from rich.table import Table from triteia.python.configs.gpus.specs import get_gpu_device_info + def timing_function(func, flops_func, kwargs, repeats=1): func_args_names = inspect.getfullargspec(func).args func_args = {arg: kwargs[arg] for arg in func_args_names if arg in kwargs} @@ -67,56 +68,53 @@ def print_results_table(title, results): console = Console() console.print(table) -def export_benchmark_results(results, filepath:str): + +def export_benchmark_results(results, filepath: str): gpu_specs = get_gpu_device_info() config_results = [] for result in results: for res in result: # ignore args if it is torch tensor - config = {k: v for k, v in res['args'].items() if not isinstance(v, torch.Tensor)} - del res['args'] - del res['output'] - config_results.append({ - 'config': config, - **res - }) - with open(filepath, 'w') as f: - json.dump({ - 'gpu_specs': gpu_specs, - 'results': config_results - }, f, indent=4) + config = { + k: v for k, v in res["args"].items() if not isinstance(v, torch.Tensor) + } + del res["args"] + del res["output"] + config_results.append({"config": config, **res}) + with open(filepath, "w") as f: + json.dump({"gpu_specs": 
gpu_specs, "results": config_results}, f, indent=4) + def format_benchmark_results(filepath: str): - with open(filepath, 'r') as f: + with open(filepath, "r") as f: data = json.load(f) - gpu_specs = data['gpu_specs'] - results = data['results'] + gpu_specs = data["gpu_specs"] + results = data["results"] df_results = [] parsed_results = [] configs = [] for result in results: - config = result['config'].copy() - del result['config'] + config = result["config"].copy() + del result["config"] res = result.copy() - func_name = res['func_name'] - del res['func_name'] + func_name = res["func_name"] + del res["func_name"] res = {f"{func_name}_{k}": v for k, v in res.items()} - parsed_results.append({ - "config": config, - **res - }) + parsed_results.append({"config": config, **res}) configs.append(config) for config in configs: - res = [d for d in parsed_results if d['config'] == config] - res = [{k: v for k, v in d.items() if k != 'config'} for d in res] + res = [d for d in parsed_results if d["config"] == config] + res = [{k: v for k, v in d.items() if k != "config"} for d in res] results = {} for r in res: results.update(r) - df_results.append({ - **config, - **results, - }) + df_results.append( + { + **config, + **results, + } + ) df = pd.DataFrame(df_results) # deduplicate rows df = df.drop_duplicates() - return gpu_specs, df \ No newline at end of file + return gpu_specs, df diff --git a/triteia/python/utils/io.py b/triteia/python/utils/io.py index dc1cb73..2dffe93 100644 --- a/triteia/python/utils/io.py +++ b/triteia/python/utils/io.py @@ -6,7 +6,7 @@ def save_tensors(tensors, path): tensors[key] = tensors[key].contiguous() save_file(tensors, path) -def read_tensors(path, prefix=None, device='cpu'): +def read_tensors(path, prefix=None, device="cpu"): tensors = {} with st.safe_open(path, framework="pt", device=device) as f: for key in f.keys(): @@ -16,4 +16,5 @@ def read_tensors(path, prefix=None, device='cpu'): if key.startswith(prefix): module_name = key.removeprefix(prefix + ".") tensors[module_name] = f.get_tensor(key) - return tensors \ No newline at end of file + return tensors + diff --git a/triteia/python/utils/quant_utils.py b/triteia/python/utils/quant_utils.py index 9456bf9..f47116c 100644 --- a/triteia/python/utils/quant_utils.py +++ b/triteia/python/utils/quant_utils.py @@ -1,23 +1,23 @@ # adapted from https://github.com/IST-DASLab/marlin/blob/2e87035acf1b117aaf2c840c32b6a2b0a6c6ca4a/conversion/convert.py import torch +import numpy as np + @torch.no_grad() def unpack_4bit_to_32bit_signed(qweight, qzeros): # Unpack 4-bit values and interpret them as signed integers unpacked_weights = torch.zeros( - (qweight.shape[0]*8, qweight.shape[1]), + (qweight.shape[0] * 8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, - requires_grad=False + requires_grad=False, ) unpacked_zeros = torch.zeros( - (qzeros.shape[0], qzeros.shape[1]*8), - dtype=torch.int8, - device=qzeros.device, - requires_grad=False + (qzeros.shape[0], qzeros.shape[1] * 8), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, ) - - for row in range(unpacked_weights.shape[0]): i = row % 8 unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF @@ -33,6 +33,7 @@ def unpack_4bit_to_32bit_signed(qweight, qzeros): ) return unpacked_weights, unpacked_zeros + 1 + @torch.no_grad() def dequantize_weight(qweight, qzeros, scales): unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros) @@ -42,6 +43,7 @@ def dequantize_weight(qweight, qzeros, scales): unpacked_qweight = 
(unpacked_qweight - unpacked_qzeros) * scales return unpacked_qweight.T + @torch.no_grad() def gptq_unpack(bits, qweight, qzeros, scales, group_size=-1): if group_size == -1: @@ -71,4 +73,39 @@ def gptq_unpack(bits, qweight, qzeros, scales, group_size=-1): weight = weight.reshape(-1, group_size, weight.shape[2]) weight = scales * (weight - zeros) weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) - return weight \ No newline at end of file + return weight + + +def unpack_2bit_from_16bit(tensor): + unpacked_values = [] + + # Define a mask for 2 bits + mask = 0b11 # This is binary for '11', which is 3 in decimal + + # Process each element in the tensor + for value in tensor: + # Extract 8 values of 2 bits each + for i in range(8): # 8 values of 2 bits each in a 16-bit number + # Shift right by i*2 positions and apply mask + unpacked_value = (value >> (i * 2)) & mask + unpacked_values.append(unpacked_value) + + return np.array(unpacked_values) + + +def pack_2bit_to_16bit(values): + if len(values) % 8 != 0: + raise ValueError("The number of values must be a multiple of 8.") + + # Create an empty list to store the packed int16 values + packed_tensor = [] + + # Process each group of 8 values + for i in range(0, len(values), 8): + packed_value = 0 + for j in range(8): + # Shift the value to its correct position and combine it with the previous values + packed_value |= (values[i + j] & 0b11) << (j * 2) + packed_tensor.append(packed_value) + + return torch.tensor(np.array(packed_tensor, dtype=np.int16)) diff --git a/triteia/tools/converters/convert_deltazip.py b/triteia/tools/converters/convert_deltazip.py index 5087eb9..08c5bcf 100644 --- a/triteia/tools/converters/convert_deltazip.py +++ b/triteia/tools/converters/convert_deltazip.py @@ -6,72 +6,25 @@ from triteia.python.utils.io import save_tensors from triteia.python.utils.quant_utils import dequantize_weight from triteia.python.utils.compressor import LosslessCompressor -from triteia.python.configs.models.llama import row_chunking_modules, uncompressed_row_chunking_modules, pack_modules +from triteia.python.configs.models.llama import ( + row_chunking_modules, + uncompressed_row_chunking_modules, + pack_modules, +) from triteia.python.nn.linear import sparse_low_precision_linear +from triteia.python.ops.utils.generator import torch_weight_to_sparse_marlin -@torch.no_grad() -def torch_weight_to_sparse_marlin(weight, scale, tp_size=1, chunk_by="column"): - """ - Args: - weight: torch.Tensor of shape (in_features, out_features) - scale: torch.Tensor of shape (1, out_features) - tp_size: tensor parallelism size - chunk_by: "column" or "row" - """ - assert chunk_by in ["column", "row"], "chunk_by must be either 'column' or 'row'" - assert weight.dim() == 2, "weight must be a 2D tensor" - assert weight.size(0) % tp_size == 0, "out_features must be divisible by tp_size" - assert weight.size(1) == scale.size(1), "out_features of weight and scale must match" - - if not weight.is_contiguous(): - weight = weight.contiguous() - if not scale.is_contiguous(): - scale = scale.contiguous() - - qweights, scales,metas = [], [], [] - for i in range(tp_size): - if chunk_by == "column": - tp_weight = weight[ - :, - i * weight.size(1) // tp_size: (i + 1) * weight.size(1) // tp_size - ] - tp_scales = scale[ - :, - i * scale.size(1) // tp_size: (i + 1) * scale.size(1) // tp_size - ] - elif chunk_by == "row": - tp_weight = weight[ - i * weight.size(0) // tp_size: (i + 1) * weight.size(0) // tp_size, - : - ] - tp_scales = scale - layer = 
sparse_low_precision_linear( - infeatures=tp_weight.size(0), - outfeatures=tp_weight.size(1), - groupsize=-1 - ) - k, m = tp_weight.size(0), tp_weight.size(1) - k_sp = k // 2 - layer.groupsize = k - layer.B = torch.empty((k_sp // 16, m * 16 // 8), dtype=torch.int) - layer.meta = torch.empty((m, k // 16), dtype=torch.int16) - layer.s = torch.empty((k_sp // (k // 2), m), dtype=torch.half) - layer.pack(tp_weight, scales=tp_scales, trans=True) - qweights.append(layer.B) - scales.append(layer.s) - metas.append(layer.meta) - return qweights, scales, metas @torch.no_grad() def convert_model(args, verbose=True): DEV = "cuda:0" - + new_tensors = {} tensors = {} packed_tensors = {} dequantized_tensors = {} remaining_keys = [] - + with st.safe_open(args.ckpt, framework="torch", device="cuda:0") as f: keys = f.keys() remaining_keys = list(f.keys()) @@ -81,7 +34,7 @@ def convert_model(args, verbose=True): if args.lossless: tensors_dtypes = json.loads(metadata["dtype"]) tensors_shapes = json.loads(metadata["shape"]) - + if args.lossless: print(f"Decompressing from lossless format...") with cp.cuda.Device(0): @@ -102,31 +55,36 @@ def convert_model(args, verbose=True): pbar = tqdm(quantized_modules, position=0, leave=True) print("Dequantizing weights...") for module in pbar: - dequantized_weight = dequantize_weight( - tensors[module + ".qweight"], - tensors[module + ".qzeros"], - tensors[module + ".scales"], - ).to(torch.float16).t().cpu() + dequantized_weight = ( + dequantize_weight( + tensors[module + ".qweight"], + tensors[module + ".qzeros"], + tensors[module + ".scales"], + ) + .to(torch.float16) + .t() + .cpu() + ) scales = tensors[module + ".scales"] dequantized_tensors[module] = (dequantized_weight, scales) remaining_keys.remove(module + ".qweight") remaining_keys.remove(module + ".qzeros") remaining_keys.remove(module + ".scales") remaining_keys.remove(module + ".g_idx") - + # now start to pack weights together pack_plan = {} for module in quantized_modules: if any([key in module for key in pack_modules.keys()]): source_layer = module.rsplit(".", 2)[0] - source_module = module.replace(source_layer+".", "") + source_module = module.replace(source_layer + ".", "") target_module = pack_modules[source_module] target_idx = int(target_module.split(":")[1]) target_module = source_layer + "." 
+ target_module.split(":")[0] if target_module not in pack_plan: pack_plan[target_module] = [] pack_plan[target_module].append((module, target_idx)) - + elif any([key in module for key in row_chunking_modules]): qweights, scales, metas = torch_weight_to_sparse_marlin( dequantized_tensors[module][0].to(DEV), @@ -142,7 +100,6 @@ def convert_model(args, verbose=True): key_weights = [] key_scales = [] plan = sorted(pack_plan[key], key=lambda x: x[1]) - print(f"Plan for {key}: {plan}") for module, idx in plan: weight, scales = dequantized_tensors[module] assert weight.shape[1] == scales.shape[1] @@ -154,7 +111,7 @@ def convert_model(args, verbose=True): torch.cuda.synchronize() del dequantized_tensors[module] torch.cuda.empty_cache() - + qweights, scales, metas = torch_weight_to_sparse_marlin( packed_tensors[key][0].to(DEV), packed_tensors[key][1].to(DEV), @@ -165,7 +122,7 @@ def convert_model(args, verbose=True): new_tensors[key + f".{idx}.qweight"] = qweight new_tensors[key + f".{idx}.scales"] = scales new_tensors[key + f".{idx}.meta"] = meta - + # # now processing remaining keys for module in remaining_keys: if any([key in module for key in uncompressed_row_chunking_modules]): @@ -173,11 +130,14 @@ def convert_model(args, verbose=True): module_name = module.removesuffix(".weight") num_rows = weight.shape[0] for i in range(args.tp_size): - tp_weight = weight[i * num_rows // args.tp_size: (i + 1) * num_rows // args.tp_size, :] + tp_weight = weight[ + i * num_rows // args.tp_size : (i + 1) * num_rows // args.tp_size, : + ] new_tensors[module_name + f".{i}.weight"] = tp_weight - + return new_tensors - + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--ckpt", type=str) @@ -186,7 +146,7 @@ def convert_model(args, verbose=True): parser.add_argument("--lossless", action="store_true") parser.add_argument("--pack", action="store_true") args = parser.parse_args() - + print("Converting model...") new_tensors = convert_model(args, verbose=True) - save_tensors(new_tensors, args.save_path) \ No newline at end of file + # save_tensors(new_tensors, args.save_path) diff --git a/triteia/tools/export_benchmark.py b/triteia/tools/export_benchmark.py index 7e0c4bb..1793086 100644 --- a/triteia/tools/export_benchmark.py +++ b/triteia/tools/export_benchmark.py @@ -1,6 +1,7 @@ import os from triteia.python.utils.benchmark import format_benchmark_results + def export(args): gpu_spec, df = format_benchmark_results(args.in_path) # get filename from path @@ -8,10 +9,14 @@ def export(args): out_path = os.path.join(args.out_path, f"{filename}_{gpu_spec['name']}.csv") df.to_csv(out_path, index=False) + if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Export benchmark results') - parser.add_argument('--in-path', type=str, help='Filepath to the benchmark results') - parser.add_argument('--out-path', type=str, help='Filepath to the exported benchmark results') + + parser = argparse.ArgumentParser(description="Export benchmark results") + parser.add_argument("--in-path", type=str, help="Filepath to the benchmark results") + parser.add_argument( + "--out-path", type=str, help="Filepath to the exported benchmark results" + ) args = parser.parse_args() - export(args) \ No newline at end of file + export(args) diff --git a/triteia/tools/verify_weights.py b/triteia/tools/verify_weights.py index d1e6c5c..f0b1994 100644 --- a/triteia/tools/verify_weights.py +++ b/triteia/tools/verify_weights.py @@ -3,10 +3,11 @@ from triteia.python.ops import 
matmul_4bit_2_4 from triteia.python.ops.utils.generator import generate_model_distribution + def check_tp_group_equal(weights, reference_weights): modules = set() tp_groups = set() - + for key in weights.keys(): # separate by . # last element - component, second last - tp id, others - module name @@ -14,22 +15,35 @@ def check_tp_group_equal(weights, reference_weights): tp_groups.add(tp_group) module_name = ".".join(key.split(".")[:-2]) modules.add(module_name) - + for module in modules: - tp_groups_in_modules = max([int(key.split(".")[-2]) for key in weights.keys() if module in key]) + 1 - components_in_modules = [key.split(".")[-1] for key in weights.keys() if module in key] + tp_groups_in_modules = ( + max([int(key.split(".")[-2]) for key in weights.keys() if module in key]) + + 1 + ) + components_in_modules = [ + key.split(".")[-1] for key in weights.keys() if module in key + ] for component in components_in_modules: - components_across_tp = [value for key, value in weights.items() if module in key and component in key] + components_across_tp = [ + value + for key, value in weights.items() + if module in key and component in key + ] # there should be at most tp_groups_in_modules tensors for each component - assert len(components_across_tp) == tp_groups_in_modules, f"Module {module} has {len(components_across_tp)} components for {component}" + assert ( + len(components_across_tp) == tp_groups_in_modules + ), f"Module {module} has {len(components_across_tp)} components for {component}" # check if there are same tensors for each component for i in range(1, len(components_across_tp)): - if torch.equal(components_across_tp[i-1], components_across_tp[i]): - print(f"Module {module} has same tensors for {component} in tp group {i-1} and {i}") - + if torch.equal(components_across_tp[i - 1], components_across_tp[i]): + print( + f"Module {module} has same tensors for {component} in tp group {i-1} and {i}" + ) + + def check_output(weights, reference_weights, module_name): target_weight = {key: value for key, value in weights.items() if module_name in key} - reference_weight = {key: value for key, value in reference_weights.items() if module_name in key} tp_groups = set() for key in weights.keys(): # separate by . 
@@ -39,10 +53,13 @@ def check_output(weights, reference_weights, module_name):
     reference_qweight = reference_weights[f"{module_name}.0.qweight"]
     reference_meta = reference_weights[f"{module_name}.0.meta"]
     reference_scale = reference_weights[f"{module_name}.0.scales"]
-
-    nr = 10
-    x = torch.randn((nr, 32 * reference_qweight.size(0)), dtype=torch.float16, device='cuda')
-    reference_output = matmul_4bit_2_4(reference_qweight, x, reference_meta, reference_scale)
+    nr = 128
+    x = torch.randn(
+        (nr, 32 * reference_qweight.size(0)), dtype=torch.float16, device="cuda"
+    )
+    reference_output = matmul_4bit_2_4(
+        reference_qweight, x, reference_meta, reference_scale
+    )
     tp_outputs = []
     tp_groups = sorted(list(tp_groups))
     for tp in tp_groups:
@@ -52,27 +69,46 @@ def check_output(weights, reference_weights, module_name):
         output = matmul_4bit_2_4(qweight, x, meta, scale)
         tp_outputs.append(output)
     tp_output = torch.cat(tp_outputs, dim=1)
-
-    print(f"reference_output: {reference_output.shape}, tp_output: {tp_output.shape}")
-    print(f"first half reference_out: \n{reference_output[:, :reference_output.size(1)//2]}\nfirst half tp_out: \n{tp_output[:, :tp_output.size(1)//2]}")
-
-    print(f"second half reference_out: \n{reference_output[:, reference_output.size(1)//2:]}\nsecond half tp_out: \n{tp_output[:, tp_output.size(1)//2:]}")
-
+    max_diff = torch.max(torch.abs(reference_output - tp_output)) / torch.mean(
+        torch.abs(reference_output)
+    )
     print(f"reference_output: \n{reference_output}\ntp_output: \n{tp_output}")
-
-    print(f"max diff: {torch.max(torch.abs(reference_output - tp_output))}")
-
+    print(f"diff: \n{reference_output-tp_output}")
+    print(f"max_diff: {max_diff}")
+
+    return max_diff
+
+
+def check_output_all_modules(weights, reference_weights):
+    modules = set()
+    for key in weights.keys():
+        # separate by .
+        # last element - component, second last - tp id, others - module name
+        module_name = ".".join(key.split(".")[:-2])
+        modules.add(module_name)
+    for module in modules:
+        if "qkv_proj" in module or "gate_up_proj" in module:
+            max_diff = check_output(weights, reference_weights, module)
+            print(f"Max difference for {module}: {max_diff}")
+
+
 def verify(args):
     print(args)
-    weights = read_tensors(args.input, device='cuda')
-    reference_weights = read_tensors(args.reference_input, device='cuda')
+    weights = read_tensors(args.input, device="cuda")
+    reference_weights = read_tensors(args.reference_input, device="cuda")
+    # check_output_all_modules(weights, reference_weights)
     check_output(weights, reference_weights, "model.layers.9.self_attn.qkv_proj")
     # check_tp_group_equal(weights, reference_weights)
 
-if __name__=="__main__":
+
+if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--input", type=str, help="Path to the input file")
-    parser.add_argument("--reference-input", default="", type=str, help="Path to the input file")
+    parser.add_argument(
+        "--reference-input", default="", type=str, help="Path to the reference input file"
+    )
     args = parser.parse_args()
-    verify(args)
\ No newline at end of file
+    verify(args)
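+
+
+# Example invocation (a sketch; the safetensors paths are hypothetical):
+#   python triteia/tools/verify_weights.py --input tp_model.safetensors \
+#       --reference-input reference_model.safetensors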