
Commit

vllm Fp8w8a8 block (#728)
shihaobai authored Feb 15, 2025
1 parent 5ab69f2 commit 22da88c
Showing 2 changed files with 54 additions and 1 deletion.
52 changes: 52 additions & 0 deletions lightllm/common/quantization/vllm_quant.py
@@ -3,6 +3,8 @@
from .quantize_method import QuantizationMethod
from .registry import QUANTMETHODS
import torch.nn.functional as F
from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul

try:
HAS_VLLM = True
@@ -150,3 +152,53 @@ def apply_pingpong_fp8(
from fp8_pingpong_gemm import cutlass_scaled_mm

return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)


@QUANTMETHODS.register(["vllm-fp8w8a8-b128"])
class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
def __init__(self):
super().__init__()
self.block_size = 128

def quantize(self, weight: torch.Tensor):
if self.is_moe:
return self.quantize_moe(weight)
qweight, weight_scale = ops.scaled_fp8_quant(
weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
)
return qweight.transpose(0, 1), weight_scale

def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
qweight, weight_scale, input_scale = weights
m, k = input_tensor.shape
n = weights[0].shape[1]
if input_scale is None:
input_scale = self.cache_manager.alloc_tensor(
(m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False
)
qinput_tensor = self.cache_manager.alloc_tensor(
(m, k), qweight.dtype, device=qweight.device, is_graph_out=False
)
per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
if out is None:
if use_custom_tensor_mananger:
out = self.cache_manager.alloc_tensor(
(m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
)
else:
out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
if n % 128 != 0:
w8a8_block_fp8_matmul(
qinput_tensor,
qweight,
input_scale,
weight_scale,
out,
(self.block_size, self.block_size),
dtype=input_tensor.dtype,
)
else:
qweight = qweight.t().contiguous().t()
input_scale = input_scale.t().contiguous().t()
torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
return out
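
Note: the two kernels imported at the top of this file, per_token_group_quant_fp8 and w8a8_block_fp8_matmul, can be summarized with a plain-PyTorch reference. The sketch below is illustrative only, not the committed Triton/CUTLASS code; the shapes and scale layout (one activation scale per token per 128-wide group along K, one weight scale per 128x128 block, weights stored as (K, N)) and the *_reference helper names are assumptions made for this example, not lightllm APIs.

import torch

def per_token_group_quant_fp8_reference(x, group=128, fp8=torch.float8_e4m3fn):
    # One scale per (token, 128-wide group along K): scale = group_max_abs / fp8_max.
    M, K = x.shape
    g = x.reshape(M, K // group, group).to(torch.float32)
    scale = (g.abs().amax(dim=-1) / torch.finfo(fp8).max).clamp(min=1e-12)  # (M, K // group)
    qx = (g / scale[:, :, None]).to(fp8).reshape(M, K)
    return qx, scale

def block_fp8_matmul_reference(qx, qw, x_scale, w_scale, block=128, out_dtype=torch.bfloat16):
    # Dequantize-then-matmul reference for a block-scaled FP8 GEMM.
    # qx: (M, K) fp8, x_scale: (M, K // block)
    # qw: (K, N) fp8, w_scale: (K // block, N // block)
    M, K = qx.shape
    _, N = qw.shape
    x = qx.to(torch.float32).reshape(M, K // block, block) * x_scale[:, :, None]
    w = qw.to(torch.float32).reshape(K // block, block, N // block, block)
    w = w * w_scale[:, None, :, None]
    return (x.reshape(M, K) @ w.reshape(K, N)).to(out_dtype)

In the committed apply() above, the Triton block GEMM path is taken whenever n is not a multiple of 128; otherwise the .t().contiguous().t() re-layout puts qweight and input_scale into column-major storage before calling torch.ops._C.cutlass_scaled_mm.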
3 changes: 2 additions & 1 deletion lightllm/server/api_cli.py
@@ -235,7 +235,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="""Quantization method: ppl-w4a16-128 | flashllm-w6a16
| ao-int4wo-[32,64,128,256] | ao-int8wo | ao-fp8w8a16 | ao-fp6w6a16
| vllm-w8a8 | vllm-fp8w8a8""",
| vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128
| triton-fp8w8a8-block128""",
)
parser.add_argument(
"--quant_cfg",
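
For completeness, this second hunk only extends the option's help text, so the argument name itself is outside the diff. Assuming the option is lightllm's --quant_type flag (an assumption; check the surrounding argparse code), selecting the new mode might look like the following sketch, with the model path as a placeholder.

# Hypothetical invocation; --quant_type and --model_dir are assumed flag names,
# and the model path is a placeholder.
python -m lightllm.server.api_server \
    --model_dir /path/to/fp8-block-quantized-model \
    --quant_type vllm-fp8w8a8-b128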
