W8A8 FP8 dispatch improvement #1

Merged
16 commits merged on Jan 21, 2025
10,399 changes: 10,399 additions & 0 deletions e -i HEAD~3q:q

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
@@ -32,6 +32,7 @@ add_library(_kernels SHARED
    src/sgl-kernel/csrc/trt_reduce_kernel.cu
    src/sgl-kernel/csrc/moe_align_kernel.cu
    src/sgl-kernel/csrc/int8_gemm_kernel.cu
    src/sgl-kernel/csrc/fp8_gemm_kernel.cu
    src/sgl-kernel/csrc/sgl_kernel_ops.cu
)

55 changes: 55 additions & 0 deletions sgl-kernel/benchmark/bench_fp8_gemm.py
@@ -0,0 +1,55 @@
import torch
import triton

from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm-fp8-fp16", "vllm-fp8-bf16", "sglang-fp8-fp16", "sglang-fp8-bf16"],
        line_names=["vllm-fp8-fp16", "vllm-fp8-bf16", "sglang-fp8-fp16", "sglang-fp8-bf16"],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider):
    M, N, K = batch_size, 4096, 8192
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    b_fp8 = b_fp8.t()  # the cutlass kernels expect b column-major
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None),
            quantiles=quantiles,
        )

    gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path="bench_fp8_res")
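
Not part of the PR, but the pieces above make a quick sanity check of the new kernel against vLLM's easy to write. The sketch below is a hedged example: it assumes both packages are installed, a GPU with FP8 support (sm_89 or newer), and the fp8_scaled_mm signature used in the benchmark.

import torch
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

# Dynamic per-tensor FP8 quantization of fp16 inputs.
a = torch.randn((16, 8192), device="cuda", dtype=torch.float16)
b = torch.randn((4096, 8192), device="cuda", dtype=torch.float16)
a_fp8, scale_a = vllm_scaled_fp8_quant(a)
b_fp8, scale_b = vllm_scaled_fp8_quant(b)

# Both kernels take b column-major; the two outputs should agree closely.
ref = vllm_scaled_mm(a_fp8, b_fp8.t(), scale_a, scale_b, torch.float16)
out = sgl_scaled_mm(a_fp8, b_fp8.t(), scale_a, scale_b, torch.float16, bias=None)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)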
Empty file added sgl-kernel/outp
Empty file.
11 changes: 11 additions & 0 deletions sgl-kernel/setup.py
@@ -2,6 +2,9 @@

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
import sys
import multiprocessing

root = Path(__file__).parent.resolve()

@@ -23,23 +26,30 @@ def update_wheel_platform_tag():


cutlass = root / "3rdparty" / "cutlass"

include_dirs = [
    cutlass.resolve() / "include",
    cutlass.resolve() / "tools" / "util" / "include",
    root / "src" / "sgl-kernel" / "csrc",
]

nvcc_flags = [
    "-O3",
    "-Xcompiler",
    "-fPIC",
    "-gencode=arch=compute_75,code=sm_75",
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90a,code=sm_90a",
    "-gencode=arch=compute_90,code=sm_90",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]


cxx_flags = ["-O3"]


libraries = ["c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
ext_modules = [
@@ -50,6 +60,7 @@ def update_wheel_platform_tag():
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
],
include_dirs=include_dirs,
2 changes: 2 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -3,6 +3,7 @@
    custom_reduce,
    init_custom_reduce,
    int8_scaled_mm,
    fp8_scaled_mm,
    moe_align_block_size,
)

@@ -12,4 +13,5 @@
"custom_dispose",
"custom_reduce",
"int8_scaled_mm",
"fp8_scaled_mm",
]
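
With fp8_scaled_mm exported, the op can also be called directly. Below is a minimal standalone sketch, assuming the calling convention exercised by the benchmark (FP8 e4m3 operands, b passed column-major, float32 scales); the naive per-token quantization is purely illustrative, not the library's quantization path.

import torch
from sgl_kernel import fp8_scaled_mm

M, N, K = 16, 4096, 8192
finfo = torch.finfo(torch.float8_e4m3fn)
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((N, K), device="cuda", dtype=torch.float16)

# Naive per-token (a) and per-channel (b) scales; illustrative only.
scale_a = (a.abs().amax(dim=1).float() / finfo.max).clamp(min=1e-12)
scale_b = (b.abs().amax(dim=1).float() / finfo.max).clamp(min=1e-12)
a_fp8 = (a / scale_a[:, None]).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
b_fp8 = (b / scale_b[:, None]).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

# b goes in column-major, matching the benchmark; output is (M, N) in fp16.
out = fp8_scaled_mm(a_fp8, b_fp8.t(), scale_a, scale_b, torch.float16, bias=None)
print(out.shape)  # torch.Size([16, 4096])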