From 0c55c559b82259ac4efe1d60b40c5f59d299bbf0 Mon Sep 17 00:00:00 2001
From: yych0745 <1398089567@qq.com>
Date: Thu, 23 Jan 2025 14:04:46 +0800
Subject: [PATCH 1/2] change bench_fp8

---
 sgl-kernel/benchmark/bench_fp8_gemm.py | 145 +++++++++++++++++++------
 sgl-kernel/benchmark/weights_shape.py  |  59 ++++++++++
 2 files changed, 173 insertions(+), 31 deletions(-)
 create mode 100644 sgl-kernel/benchmark/weights_shape.py

diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py
index e68695a3f39..cfcdb2decd0 100644
--- a/sgl-kernel/benchmark/bench_fp8_gemm.py
+++ b/sgl-kernel/benchmark/bench_fp8_gemm.py
@@ -1,37 +1,19 @@
 import torch
 import torch.nn.functional as F
 import triton
-from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
+import argparse
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
+from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
+from weights_shape import WEIGHT_SHAPES
+from typing import Callable, Iterable, List, Tuple
+from torch.utils.benchmark import Measurement as TMeasurement
+import itertools
+import copy
+import time
+import pickle as pkl
 
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["batch_size"],
-        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
-        x_log=False,
-        line_arg="provider",
-        line_vals=[
-            "vllm-fp8-fp16",
-            "vllm-fp8-bf16",
-            "sglang-fp8-fp16",
-            "sglang-fp8-bf16",
-        ],
-        line_names=[
-            "vllm-fp8-fp16",
-            "vllm-fp8-bf16",
-            "sglang-fp8-fp16",
-            "sglang-fp8-bf16",
-        ],
-        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
-        ylabel="GB/s",
-        plot_name="fp8 scaled matmul",
-        args={},
-    )
-)
-def benchmark(batch_size, provider):
-    M, N, K = batch_size, 4096, 8192
+def bench(dtype, M, K, N, provider):
     a = torch.ones((M, K), device="cuda") * 5.0
     b = torch.ones((N, K), device="cuda") * 5.0
     scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
@@ -41,8 +23,6 @@ def benchmark(batch_size, provider):
     b_fp8 = b_fp8.t()
     quantiles = [0.5, 0.2, 0.8]
 
-    dtype = torch.float16 if "fp16" in provider else torch.bfloat16
-
     if "vllm-fp8" in provider:
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
@@ -59,5 +39,108 @@ def benchmark(batch_size, provider):
     gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
+def run(dtype: torch.dtype,
+        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+    results = []
+    for m, k, n in MKNs:
+        for provider in ["vllm-fp8", "sglang-fp8"]:
+            gbps = bench(dtype, m, k, n, provider)
+            results.append((provider, m, k, n, *gbps))
+
+    return results
+
+def run_model_bench(args):
+    print("Benchmarking models:")
+
+    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
+        KNs = []
+        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
+            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+            KNs.append(KN)
+        return KNs
+
+    model_bench_data = []
+    models_tps = list(itertools.product(args.models, args.tp_sizes))
+    for model, tp_size in models_tps:
+        Ms = args.batch_sizes
+        KNs = model_shapes(model, tp_size)
+        MKNs = []
+        for m in Ms:
+            for k, n in KNs:
+                MKNs.append((m, k, n))
+
+        data = run(args.dtype, MKNs)
+        model_bench_data.append(data)
+
+    # Print all results
+    for data, model_tp in zip(model_bench_data, models_tps):
+        model, tp_size = model_tp
+        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
====") + print(f"{'Provider':<15} {'M':<10} {'K':<10} {'N':<10} {'GB/s':<10} {'Max GB/s':<10} {'Min GB/s':<10}") + print("=" * 70) + + for provider, m, k, n, gbps, max_gbps, min_gbps in data: + print(f"{provider:<15} {m:<10} {k:<10} {n:<10} {gbps:<10.2f} {max_gbps:<10.2f} {min_gbps:<10.2f}") + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + +def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float16 + raise ValueError(f"unsupported dtype: {dt}") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=""" +Benchmark Cutlass GEMM. + +To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + +To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + +To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + +Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """ + ) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['fp8']") + + subparsers = parser.add_subparsers(dest="cmd") + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark") + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes") + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=[16], + help="List of batch sizes") -benchmark.run(print_data=True, show_plots=True, save_path="bench_fp8_res") + args = parser.parse_args() + + if args.cmd == "model_bench": + run_model_bench(args) diff --git a/sgl-kernel/benchmark/weights_shape.py b/sgl-kernel/benchmark/weights_shape.py new file mode 100644 index 00000000000..dcd41d8aba4 --- /dev/null +++ b/sgl-kernel/benchmark/weights_shape.py @@ -0,0 +1,59 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + 
+        ([8192, 8192], 0),
+        ([8192, 59136], 1),
+        ([29568, 8192], 0),
+    ],
+    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
+        ([2048, 3072], 1),
+        ([2048, 4096], 1),
+        ([2048, 2048], 0),
+        ([2048, 576], 0),
+        ([2048, 21888], 1),
+        ([10944, 2048], 0),
+        ([2048, 2816], 1),
+        ([1408, 2048], 0),
+    ],
+}
\ No newline at end of file

From 9d1d10c827f9edb89930b05decd473d0dcab4ea1 Mon Sep 17 00:00:00 2001
From: yych0745 <1398089567@qq.com>
Date: Thu, 23 Jan 2025 14:32:59 +0800
Subject: [PATCH 2/2] clean code

---
 sgl-kernel/benchmark/bench_fp8_gemm.py | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py
index cfcdb2decd0..343c3714f92 100644
--- a/sgl-kernel/benchmark/bench_fp8_gemm.py
+++ b/sgl-kernel/benchmark/bench_fp8_gemm.py
@@ -100,20 +100,13 @@ def to_torch_dtype(dt):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description="""
-Benchmark Cutlass GEMM.
+    Benchmark Cutlass GEMM.
+    To run dimensions from a model:
+        python3 bench_fp8_gemm.py --dtype fp8 model_bench --models meta-llama/Llama-3.1-8B-Instruct --batch-sizes 16 --tp-sizes 1
 
-To run square GEMMs:
-    python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
-
-To run constant N and K and sweep M:
-    python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
-
-To run dimensions from a model:
-    python3 bench_fp8_gemm.py --dtype fp8 model_bench --models meta-llama/Llama-3.1-8B-Instruct --batch-sizes 16 --tp-sizes 1
-
-Output:
-    - a .pkl file that contains a list of raw benchmark result tuples for the vllm and sglang fp8 GEMMs.
-    """
+    Output:
+        - a .pkl file that contains a list of raw benchmark result tuples for the vllm and sglang fp8 GEMMs.
+    """
     )
 
     parser.add_argument("--dtype",
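
A note on consuming the output, not part of the patch series: run_model_bench stores each collected row as a plain tuple of (provider, M, K, N, median GB/s, min GB/s, max GB/s) in a file named model_bench-<dtype>-<timestamp>.pkl. The snippet below is a minimal illustrative sketch for loading the newest such file back for offline inspection; the glob pattern and the newest-file choice are assumptions for illustration, not code added by the patches.

    import glob
    import os
    import pickle

    # Load the most recently written results file produced by run_model_bench.
    path = max(glob.glob("model_bench-*.pkl"), key=os.path.getmtime)
    with open(path, "rb") as f:
        rows = pickle.load(f)

    # Each row is (provider, M, K, N, median GB/s, min GB/s, max GB/s).
    for provider, m, k, n, gbps, min_gbps, max_gbps in rows:
        print(f"{provider:<12} M={m:<6} K={k:<6} N={n:<6} {gbps:8.2f} GB/s")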