From 40336791c9a9ca922dc002f985fd32b74836b851 Mon Sep 17 00:00:00 2001
From: Xiaozhe Yao
Date: Wed, 3 Jul 2024 15:36:23 +0200
Subject: [PATCH] bench bmm

---
 benchmarks/bench_bmm.py           | 34 +++++++++++++++++++++++++++++++
 tests/ops/test_bmm.py             | 10 ++++++---
 triteia/python/utils/benchmark.py |  1 +
 3 files changed, 42 insertions(+), 3 deletions(-)
 create mode 100644 benchmarks/bench_bmm.py

diff --git a/benchmarks/bench_bmm.py b/benchmarks/bench_bmm.py
new file mode 100644
index 0000000..5f1e06c
--- /dev/null
+++ b/benchmarks/bench_bmm.py
@@ -0,0 +1,34 @@
+import torch
+from triteia.python.ops import bmm_4bit_2_4_forloop, gen_batched_sparse_quant4_NT, bmm_4bit_2_4
+from triteia.python.utils import timing_function, print_results_table
+from triteia.python.configs.models.llama import llama_shapes
+
+flops_func = lambda b, m, n, k: 2 * b * m * n * k
+
+def benchmark(b, m, n, k, dev="cuda", groupsize=-1):
+    x = torch.randn((b, 1, k), dtype=torch.float16, device=dev)
+    weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT(
+        b, m, k, groupsize=groupsize, device=dev
+    )
+    def fp16_func(x, weight_ref):
+        return torch.matmul(x, weight_ref)
+    def w4_2_4_forloop(qweight, x, meta, scale):
+        return bmm_4bit_2_4_forloop(qweight, x, meta, scale)
+    def w4_2_4_native(qweight, x, meta, scale):
+        return bmm_4bit_2_4(qweight, x, meta, scale)
+
+    w4_2_4_forloop_result = timing_function(
+        w4_2_4_forloop, flops_func, kwargs={"b": b, "m": m, "n": n, "k": k, "qweight": qweight, "x": x, "meta": meta, "scale": scale}, repeats=5
+    )
+    fp16_result = timing_function(
+        fp16_func, flops_func, kwargs={"b": b, "m": m, "n": n, "k": k, "x": x, "weight_ref": weight_ref}, repeats=5
+    )
+    w4_2_4_native_result = timing_function(
+        w4_2_4_native, flops_func, kwargs={"b": b, "m": m, "n": n, "k": k, "qweight": qweight, "x": x, "meta": meta, "scale": scale}, repeats=5
+    )
+    results = [fp16_result, w4_2_4_forloop_result, w4_2_4_native_result]
+    print_results_table(f"bmm b={b},m={m},n={n},k={k}", results)
+
+if __name__ == "__main__":
+    benchmark(2, 256, 32, 256)
+    benchmark(8, 4096, 32, 4096)
\ No newline at end of file
diff --git a/tests/ops/test_bmm.py b/tests/ops/test_bmm.py
index 8c7fd48..ed07dde 100644
--- a/tests/ops/test_bmm.py
+++ b/tests/ops/test_bmm.py
@@ -38,12 +38,16 @@ def test_tiny(self):
         self.run_problem(16, 256, 16, 512, groupsize=-1)
         self.run_problem(16, 512, 16, 256, groupsize=-1)
         self.run_problem(8, 256, 16, 256, groupsize=-1)
+        self.run_problem(8, 512, 16, 256, groupsize=-1)
+        self.run_problem(8, 256, 16, 512, groupsize=-1)
+        self.run_problem(8, 512, 16, 256, groupsize=-1)
         self.run_problem(4, 512, 16, 512, groupsize=-1)
         self.run_problem(4, 256, 16, 512, groupsize=-1)
-        self.run_problem(8, 512, 16, 256, groupsize=-1)
-
+        self.run_problem(4, 256, 16, 512, groupsize=-1)
+        self.run_problem(4, 512, 16, 256, groupsize=-1)
+
     def test_llama(self):
-        bszs = [4, 8, 16]
+        bszs = [2, 4, 5, 6, 7, 8, 9, 16]
         for _, layers in llama_shapes.items():
             for layer in layers:
                 for bsz in bszs:
diff --git a/triteia/python/utils/benchmark.py b/triteia/python/utils/benchmark.py
index cd6ad21..26ea258 100644
--- a/triteia/python/utils/benchmark.py
+++ b/triteia/python/utils/benchmark.py
@@ -28,6 +28,7 @@ def timing_function(func, flops_func, kwargs, repeats=1):
     if flops_func:
         total_flops = flops_func(**flops_func_args) # FLOPS
         perf_flops = total_flops/elapsed # FlOPS/ms
+        # total_tflops = total_flops/1e12 # TFLOPS
     if gpu_info:
         mfu = 100 * perf_flops/1e9/gpu_info["fp16_tflops"]
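
A quick note for reviewers (not part of the patch): besides the two shapes hard-coded in the script's __main__ block, the new benchmark can also be driven programmatically to sweep batch sizes. The sketch below is a minimal, hypothetical example, not the patch's own interface; it assumes the repository root as the working directory, an installed triteia build, and a visible CUDA GPU.

    # Hypothetical usage sketch, not included in the patch.
    # Load benchmarks/bench_bmm.py by path and call its benchmark() helper
    # for a few batch sizes; the m, n, k values mirror the larger case in
    # the script's __main__ block.
    import importlib.util

    spec = importlib.util.spec_from_file_location("bench_bmm", "benchmarks/bench_bmm.py")
    bench_bmm = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(bench_bmm)  # runs top-level imports; the __main__ guard does not fire

    for b in (2, 4, 8):
        bench_bmm.benchmark(b, 4096, 32, 4096)

Running the file directly (python benchmarks/bench_bmm.py) reproduces the two hard-coded shapes.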