deepseekv3 bmm noquant and fix moe gemm bug. (#745)
Co-authored-by: shihaobai <[email protected]>
Co-authored-by: shihaobai <[email protected]>
3 people authored Feb 22, 2025
1 parent 808d832 commit c483b1e
Showing 5 changed files with 79 additions and 13 deletions.
@@ -73,16 +73,20 @@ def _post_load_weights(self) -> None:
and (not self.static_activation or self.input_scale is not None)
):
if self.weight_scale.ndim > 1:
-self.weight_scale = self.weight_scale.transpose(0, 1).cuda(self.device_id_)
+# make the k dim more contiguous; most split-k GEMM kernels are likely faster this way
+self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
self.weight = [
-self.weight.transpose(0, 1).cuda(self.device_id_),
+# make the k dim more contiguous; most split-k GEMM kernels are likely faster this way
+self.weight.cuda(self.device_id_).transpose(0, 1),
self.weight_scale,
self.input_scale,
]
else:
self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
return
-self.weight = self.weight.to(self.data_type_).transpose(0, 1).cuda(self.device_id_)
+
+# make the k dim more contiguous; most split-k GEMM kernels are likely faster this way
+self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)


class MMWeight(MMWeightTpl):
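The reordering in the hunk above is easiest to see from the resulting strides. A minimal sketch with hypothetical shapes (not part of the patch): copying the row-major (n, k) weight to the GPU first and transposing afterwards gives a (k, n) view whose k axis is the stride-1, fast-moving dimension in memory, which is the layout the comment in the hunk expects most split-k GEMM kernels to prefer.

# Minimal sketch, hypothetical shapes; shows the layout produced by the new
# `.cuda(...).transpose(0, 1)` ordering, not the loader itself.
import torch

n, k = 256, 512                                 # hypothetical output/input dims
w = torch.randn(n, k, dtype=torch.float16)      # row-major on CPU, strides (k, 1)

if torch.cuda.is_available():
    w_gpu = w.cuda().transpose(0, 1)            # order used by the patch
    # shape (k, n), strides (1, k): elements adjacent along k are adjacent in
    # memory, which split-k GEMM kernels generally favor.
    print(w_gpu.shape, w_gpu.stride())          # torch.Size([512, 256]) (1, 512)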
10 changes: 5 additions & 5 deletions lightllm/common/fused_moe/grouped_fused_moe.py
@@ -331,7 +331,7 @@ def grouped_matmul_kernel(
for step_k in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
# hint to Triton compiler to do proper loop pipelining
# tl.multiple_of(a_ptrs, [16, 16])
-tl.multiple_of(b_ptrs, [16, 16])
+# tl.multiple_of(b_ptrs, [16, 16])

if use_fp8_w8a8:
a = tl.load(a_ptrs, mask=(offs_am[None, :] < cur_m) & (offs_k[:, None] < k))
@@ -464,10 +464,10 @@ def grouped_matmul(
token_input_scale,
expert_to_weights_scale,
expert_to_weights_scale.stride(0)
-if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
else 0,
expert_to_weights_scale.stride(1)
-if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
else 0,
expert_to_weights_scale.stride(2)
if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
@@ -532,10 +532,10 @@ def grouped_matmul(
token_input_scale,
expert_to_weights_scale,
expert_to_weights_scale.stride(0)
-if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
else 0,
expert_to_weights_scale.stride(1)
-if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
else 0,
expert_to_weights_scale.stride(2)
if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
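These guard changes are the "fix moe gemm bug" part of the title: with the old ndim == 2 checks, a 3-D per-expert block-scale tensor had its first two strides silently replaced by 0 at both launch sites. A minimal sketch of the pattern, where safe_stride is a hypothetical helper (the patch writes the conditionals inline):

# Hedged sketch of the guarded-stride pattern used at the kernel launch sites;
# `safe_stride` is a hypothetical helper, not a function from the patch.
import torch

def safe_stride(t, dim):
    # Return t.stride(dim) only if the tensor exists and actually has that
    # dimension; otherwise pass 0 so the Triton kernel argument stays defined.
    return t.stride(dim) if t is not None and t.ndim > dim else 0

# Hypothetical per-expert block-wise scales: (num_experts, n_blocks, k_blocks).
expert_to_weights_scale = torch.ones(8, 16, 32)
print([safe_stride(expert_to_weights_scale, d) for d in range(3)])
# [512, 32, 1] -- under the old `ndim == 2` guards the first two would be 0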
1 change: 0 additions & 1 deletion lightllm/common/quantization/vllm_quant.py
@@ -198,7 +198,6 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
dtype=input_tensor.dtype,
)
else:
-qweight = qweight.t().contiguous().t()
input_scale = input_scale.t().contiguous().t()
torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
return out
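For context on the deleted line: .t().contiguous().t() keeps the logical shape but rewrites the storage column-major on every call. With the weight now stored transposed at load time (the MMWeight change above), the qweight passed here already has that layout, so the per-call copy looks redundant and only the input_scale relayout remains. A small illustration of what the idiom does (not part of the patch):

# Illustration only: `.t().contiguous().t()` leaves the shape unchanged but
# makes dim 0 the stride-1 dimension, i.e. a column-major copy of the tensor.
import torch

q = torch.randn(128, 64)
print(q.stride())                  # (64, 1)   row-major
q_cm = q.t().contiguous().t()
print(q_cm.shape, q_cm.stride())   # torch.Size([128, 64]) (1, 128)   column-major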
@@ -18,6 +18,7 @@
ROWBMMWeightNoTp,
)
from functools import partial
+from ..triton_kernel.weight_dequant import weight_dequant


class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
@@ -116,8 +117,15 @@ def _load_vb_scale(self, kv_b_proj_scale_, block_size):
def load_hf_weights(self, weights):
if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
+# for deepseek_v3, the bmm operator is not quantized
+if self.quant_cfg.quantized_weight:
+    kv_b_proj_ = weight_dequant(
+        kv_b_proj_.cuda(),
+        weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix].cuda(),
+    ).cpu()
weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)

if (
self.quant_cfg.quantized_weight
and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -184,15 +192,11 @@ def _init_qkvo(self):
f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
self.data_type_,
split_n_embed=self.tp_q_head_num_,
-weight_scale_suffix=self.weight_scale_suffix,
-act_scale_suffix=self.act_scale_suffix,
)
self.v_b_proj_ = ROWBMMWeight(
f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
self.data_type_,
split_n_embed=self.tp_q_head_num_,
-weight_scale_suffix=self.weight_scale_suffix,
-act_scale_suffix=self.act_scale_suffix,
)
if self.enable_cc_method:
self.cc_kv_b_proj_ = ROWMMWeight(
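The hunks above are the "bmm noquant" part of the title: DeepSeek-V3's k_b_proj / v_b_proj feed an unquantized bmm, so a block-quantized checkpoint has its kv_b_proj weight dequantized to bf16 before the k_b/v_b split, and the two ROWBMMWeight entries no longer carry scale suffixes. A hedged sketch of just that step (kv_b_fp8 and kv_b_scale are hypothetical stand-ins for the tensors the loader reads out of the weights dict):

# Hedged sketch of the dequantize-before-split step; tensor names are
# hypothetical stand-ins for the checkpoint entries read by load_hf_weights.
import torch
from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant

def dequant_kv_b_proj(kv_b_fp8: torch.Tensor, kv_b_scale: torch.Tensor) -> torch.Tensor:
    # FP8 weight + per-block scales -> bf16 weight, staged back to CPU as in the
    # patched load_hf_weights, so the k_b/v_b split sees an unquantized tensor.
    return weight_dequant(kv_b_fp8.cuda(), kv_b_scale.cuda()).cpu()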
59 changes: 59 additions & 0 deletions lightllm/models/deepseek2/triton_kernel/weight_dequant.py
@@ -0,0 +1,59 @@
# adapted from
# https://github.com/deepseek-ai/DeepSeek-V3/blob/f09f5fa321f5a421704136c0463b1eaca6557712/inference/kernel.py
import torch
import triton
import triton.language as tl
from triton import Config


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The per-block scale tensor of shape (ceil(M / block_size), ceil(N / block_size)).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y.to(torch.bfloat16)


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
n = tl.cdiv(N, BLOCK_SIZE)
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
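A small sanity check (not part of the commit) that the kernel scales each block_size x block_size tile of x by its entry in s. It uses fp16 inputs so it runs without FP8 hardware; the real caller passes the FP8 kv_b_proj weight and its per-block scales.

# Hedged round-trip check for weight_dequant; fp16 stands in for the FP8 weight.
import torch
from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant

if torch.cuda.is_available():
    M, N, bs = 256, 384, 128
    x = torch.randn(M, N, dtype=torch.float16, device="cuda")
    s = torch.rand(M // bs, N // bs, dtype=torch.float32, device="cuda") + 0.5
    y = weight_dequant(x, s, block_size=bs)      # returns bfloat16
    # reference: multiply every (bs x bs) tile of x by its per-block scale
    ref = (x.float().view(M // bs, bs, N // bs, bs) * s[:, None, :, None]).view(M, N)
    assert torch.allclose(y.float(), ref, atol=1e-2, rtol=1e-2)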
