add more mnk for benchmark #7

Open · wants to merge 2 commits into base: main_w8a8_fp8
138 changes: 107 additions & 31 deletions sgl-kernel/benchmark/bench_fp8_gemm.py
@@ -1,37 +1,19 @@
 import torch
 import torch.nn.functional as F
 import triton
-from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
+import argparse
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
+from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
+from weights_shape import WEIGHT_SHAPES
+from typing import Callable, Iterable, List, Tuple
+from torch.utils.benchmark import Measurement as TMeasurement
+import itertools
+import copy
+import time
+import pickle as pkl


-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["batch_size"],
-        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
-        x_log=False,
-        line_arg="provider",
-        line_vals=[
-            "vllm-fp8-fp16",
-            "vllm-fp8-bf16",
-            "sglang-fp8-fp16",
-            "sglang-fp8-bf16",
-        ],
-        line_names=[
-            "vllm-fp8-fp16",
-            "vllm-fp8-bf16",
-            "sglang-fp8-fp16",
-            "sglang-fp8-bf16",
-        ],
-        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
-        ylabel="GB/s",
-        plot_name="fp8 scaled matmul",
-        args={},
-    )
-)
-def benchmark(batch_size, provider):
-    M, N, K = batch_size, 4096, 8192
+def bench(dtype, M, N, K, provider):
     a = torch.ones((M, K), device="cuda") * 5.0
     b = torch.ones((N, K), device="cuda") * 5.0
     scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
@@ -41,8 +23,6 @@ def benchmark(batch_size, provider):
     b_fp8 = b_fp8.t()
     quantiles = [0.5, 0.2, 0.8]

-    dtype = torch.float16 if "fp16" in provider else torch.bfloat16
-
     if "vllm-fp8" in provider:
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
@@ -59,5 +39,101 @@
     gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)


+def run(dtype: torch.dtype,
+        MKNs: Iterable[Tuple[int, int, int]]) -> List[Tuple]:
+    results = []
+    for m, k, n in MKNs:
+        for provider in ["vllm-fp8", "sglang-fp8"]:
+            # bench() takes (dtype, M, N, K) while the tuples here are
+            # (m, k, n), so k and n must be swapped when calling it.
+            gbps = bench(dtype, m, n, k, provider)
+            results.append((provider, m, k, n, *gbps))
+
+    return results
+
+
+def run_model_bench(args):
+    print("Benchmarking models:")
+
+    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
+        KNs = []
+        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
+            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+            KNs.append(KN)
+        return KNs
+
+    model_bench_data = []
+    models_tps = list(itertools.product(args.models, args.tp_sizes))
+    for model, tp_size in models_tps:
+        Ms = args.batch_sizes
+        KNs = model_shapes(model, tp_size)
+        MKNs = []
+        for m in Ms:
+            for k, n in KNs:
+                MKNs.append((m, k, n))
+
+        data = run(args.dtype, MKNs)
+        model_bench_data.append(data)
+
+    # Print all results
+    for data, model_tp in zip(model_bench_data, models_tps):
+        model, tp_size = model_tp
+        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
+        print(f"{'Provider':<15} {'M':<10} {'K':<10} {'N':<10} {'GB/s':<10} {'Max GB/s':<10} {'Min GB/s':<10}")
+        print("=" * 70)
+
+        for provider, m, k, n, gbps, max_gbps, min_gbps in data:
+            print(f"{provider:<15} {m:<10} {k:<10} {n:<10} {gbps:<10.2f} {max_gbps:<10.2f} {min_gbps:<10.2f}")
+
+    timestamp = int(time.time())
+
+    all_data = []
+    for d in model_bench_data:
+        all_data.extend(d)
+    # pickle all data
+    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
+        pkl.dump(all_data, f)
+
+
+def to_torch_dtype(dt):
+    if dt == "int8":
+        return torch.int8
+    if dt == "fp8":
+        # The scaled-mm kernels take an *output* dtype; fp8 runs emit fp16.
+        return torch.float16
+    raise ValueError(f"unsupported dtype: {dt}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="""
+Benchmark Cutlass GEMM.
+
+    To run dimensions from a model:
+        python3 sgl-kernel/benchmark/bench_fp8_gemm.py --dtype fp8 model_bench --models meta-llama/Llama-3.1-8B-Instruct --batch-sizes 16 --tp-sizes 1
+
+    Output:
+        - a .pkl file containing a list of raw result tuples, (provider, M, K, N, GB/s, max GB/s, min GB/s), for the various GEMMs.
+    """
+    )
+
+    parser.add_argument("--dtype",
+                        type=to_torch_dtype,
+                        required=True,
+                        help="Available options are ['fp8']")
+
+    subparsers = parser.add_subparsers(dest="cmd")
+
+    model_parser = subparsers.add_parser("model_bench")
+    model_parser.add_argument("--models",
+                              nargs="+",
+                              type=str,
+                              default=["meta-llama/Llama-3.1-8B-Instruct"],
+                              help="List of models to benchmark")
+    model_parser.add_argument("--tp-sizes",
+                              nargs="+",
+                              type=int,
+                              default=[1],
+                              help="List of tensor parallel sizes")
+    model_parser.add_argument("--batch-sizes",
+                              nargs="+",
+                              type=int,
+                              default=[16],
+                              help="List of batch sizes")
+
-benchmark.run(print_data=True, show_plots=True, save_path="bench_fp8_res")
+    args = parser.parse_args()
+
+    if args.cmd == "model_bench":
+        run_model_bench(args)
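For reference, a minimal sketch of reading the pickled results back, assuming the record layout appended in run() above; the filename's timestamp is hypothetical:

import pickle as pkl

# Each record is (provider, M, K, N, GB/s, max GB/s, min GB/s), as appended in run().
with open("model_bench-torch.float16-1700000000.pkl", "rb") as f:  # hypothetical name
    records = pkl.load(f)

for provider, m, k, n, gbps, max_gbps, min_gbps in records:
    print(f"{provider:<12} M={m:<6} K={k:<6} N={n:<6} {gbps:.2f} GB/s")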
59 changes: 59 additions & 0 deletions sgl-kernel/benchmark/weights_shape.py
@@ -0,0 +1,59 @@
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
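
As a cross-check on the format above, a small standalone sketch (shapes_for_tp is a hypothetical helper mirroring model_shapes() in bench_fp8_gemm.py) showing how a TP size rewrites these entries:

import copy

def shapes_for_tp(model_name: str, tp_size: int):
    # Divide only the TP-split dimension; the other dimension stays intact.
    kns = []
    for kn, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
        kn[tp_split_dim] //= tp_size
        kns.append(kn)
    return kns

# shapes_for_tp("meta-llama/Llama-3.1-8B-Instruct", 2)
# -> [[4096, 3072], [2048, 4096], [4096, 14336], [7168, 4096]]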