Vit triton (fp + quant / tp + dp), custom image pre_process #663

Open · wants to merge 10 commits into base: main
@@ -1,3 +1,4 @@
+import os
 import torch
 from abc import ABC, abstractmethod
 from lightllm.utils.dist_utils import get_world_size, get_rank
@@ -20,6 +21,7 @@ class BaseWeightTpl(BaseWeight):
     def __init__(self):
         self.world_size_ = get_world_size()
         self.tp_rank_ = get_rank()
+        self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_))

     def load_hf_weights(self, weights):
         pass
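One detail the `device_id_` lookup relies on: `os.getenv` returns its default argument unchanged when the variable is unset, so passing the int `tp_rank_` as the fallback is safe even though environment values are strings. A standalone sketch of the pattern (the function name is illustrative):

```python
import os

def resolve_device_id(tp_rank: int) -> int:
    # os.getenv hands back the default as-is when CURRENT_DEVICE_ID is unset,
    # so int() receives either the env string (e.g. "3") or the int tp rank.
    return int(os.getenv("CURRENT_DEVICE_ID", tp_rank))

assert resolve_device_id(2) == 2  # unset: falls back to the tp rank
os.environ["CURRENT_DEVICE_ID"] = "5"
assert resolve_device_id(2) == 5  # set: the explicit device wins
```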
@@ -1,3 +1,4 @@
+import os
 import torch
 from .base_weight import BaseWeight
 from lightllm.utils.dist_utils import get_world_size, get_rank
@@ -21,6 +22,7 @@ def __init__(
         self.split_inter_size = split_inter_size
         self.data_type_ = data_type
         self.tp_rank_ = get_rank()
+        self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_))
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
         self.w2_list = [None] * self.n_routed_experts
@@ -113,10 +115,10 @@ def load_hf_weights(self, weights):
         self._fuse()

     def _cuda(self, cpu_tensor):
-        if self.tp_rank_ is None:
+        if self.device_id_ is None:
             return cpu_tensor.contiguous().to(self.data_type_).cuda()
         else:
-            return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
+            return cpu_tensor.contiguous().to(self.data_type_).cuda(self.device_id_)

     def verify_load(self):
         return self.w1 is not None and self.w2 is not None
18 changes: 10 additions & 8 deletions lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -28,7 +28,9 @@ def set_quant_method(self, quant_method):

     def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
         if self.quant_method is not None:
-            return self.quant_method.apply(input_tensor, self.weight, self.bias, out)
+            return self.quant_method.apply(
+                input_tensor, self.weight, self.bias, out, use_custom_tensor_mananger=use_custom_tensor_mananger
+            )
         if out is None:
             shape = (input_tensor.shape[0], self.weight.shape[1])
             dtype = input_tensor.dtype
@@ -47,9 +49,9 @@ def _post_load_weights(self):
         if all(w is not None for w in [self.weight, self.weight_scale, self.input_scale]):
             self.weight = self.quant_method.quantize((self.weight, self.weight_scale, self.input_scale))
         else:
-            self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.tp_rank_))
+            self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
             return
-        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+        self.weight = self.weight.transpose(0, 1).cuda(self.device_id_)


 class MMWeight(MMWeightTpl):
@@ -84,7 +86,7 @@ def load_hf_weights(self, weights):
         self.weight = weight[self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
-            self.bias = bias.cuda(self.tp_rank_)
+            self.bias = bias.cuda(self.device_id_)

         if STATIC_QUANT and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name].to(torch.float)[self.start : self.end]
@@ -120,7 +122,7 @@ def load_hf_weights(self, weights):
         self.weight = weight[:, self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name]
-            self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.tp_rank_)
+            self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_)

         if STATIC_QUANT and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name].to(torch.float)
@@ -191,7 +193,7 @@ def _fuse(self):

         if self.has_bias:
             if self.bias is None and all(b is not None for b in self.biases):
-                self.bias = torch.cat(self.biases, dim=0).cuda(self.tp_rank_)
+                self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_)
         return self

     def load_hf_weights(self, weights):
@@ -271,7 +273,7 @@ def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
         return torch.addbmm(self.bias, input_tensor, self.weight, out=out)

     def _post_load_weights(self):
-        self.weight = self.weight.cuda(self.tp_rank_)
+        self.weight = self.weight.cuda(self.device_id_)


 class BMMWeight(BMMWeightTpl):
@@ -318,4 +320,4 @@ def __init__(
         super().__init__(weight_name, data_type, split_n_embed, bias_name)

     def _post_load_weights(self):
-        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+        self.weight = self.weight.transpose(0, 1).cuda(self.device_id_)
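Each `.cuda(self.tp_rank_)` above becomes `.cuda(self.device_id_)` because, once data parallelism is combined with tensor parallelism, a worker's tp rank no longer has to equal its physical GPU index. A brief illustration of the difference (the device indices are made up for the example):

```python
import torch

tp_rank, device_id = 0, 2  # e.g. tp rank 0 of the second dp replica owns GPU 2
w = torch.zeros(4, 4)
if torch.cuda.is_available() and torch.cuda.device_count() > device_id:
    assert w.cuda(tp_rank).device.index == 0    # would land on the wrong GPU
    assert w.cuda(device_id).device.index == 2  # pinned to the owning GPU
```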
@@ -12,9 +12,9 @@ def __init__(self, weight_name, data_type, bias_name=None):

     def load_hf_weights(self, weights):
         if self.weight_name in weights:
-            self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.tp_rank_)
+            self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.device_id_)
         if self.bias_name in weights:
-            self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.tp_rank_)
+            self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.device_id_)

     def verify_load(self):
         load_ok = True
@@ -32,7 +32,7 @@ def __init__(self, weight_name, data_type, bias_name=None):

     def load_hf_weights(self, weights):
         if self.weight_name in weights:
-            self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.tp_rank_)
+            self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.device_id_)


 class TpNormWeight(NormWeight):
@@ -41,10 +41,10 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
         self.split_n_embed = split_n_embed

     def load_hf_weights(self, weights):
-        start = self.offset + self.split_n_embed * self.tp_rank_
-        end = self.offset + self.split_n_embed * (self.tp_rank_ + 1)
+        start = self.split_n_embed * self.device_id_
+        end = self.split_n_embed * (self.device_id_ + 1)

         if self.weight_name in weights:
-            self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.tp_rank_)
+            self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_)
         if self.bias_name in weights:
-            self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.tp_rank_)
+            self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.device_id_)
15 changes: 10 additions & 5 deletions lightllm/common/quantization/ppl_quant.py
@@ -1,3 +1,4 @@
+import os
 import torch
 from .quantize_method import QuantizationMethod
 from .registry import QUANTMETHODS
@@ -9,6 +10,7 @@ class PPLW4A16QuantizationMethod(QuantizationMethod):
     def __init__(self, group_size=128):
         super().__init__()
         self.group_size = group_size
+        self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID"))

     def quantize(self, weight: torch.Tensor):
         """
@@ -17,13 +19,13 @@ def quantize(self, weight: torch.Tensor):
             qweight: [K, N//8] int32 (packed int4*8) new pack_order
             q_scale: [K//group_size, N] int32
         """
-        weight = weight.to(dtype=torch.float16).cuda()
+        weight = weight.to(dtype=torch.float16).cuda(self.device_id_)
         from lightllm_ppl_int4_kernel import int4_weight_encode

         qweight_new, q_scale = int4_weight_encode(weight, self.group_size)
         return qweight_new, q_scale

-    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
         """
         input_tensor is activation: (M, K) float16
         weights: [qweight, scale_weight]
@@ -38,7 +40,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
             shape = (input_tensor.shape[0], qweight.shape[0] * 8)
             dtype = input_tensor.dtype
             device = input_tensor.device
-            out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False)
+            if use_custom_tensor_mananger:
+                out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False)
+            else:
+                out = torch.empty(shape, dtype=dtype, device=device)
         from lightllm_ppl_int4_kernel import matmul_i4_fp16
         from lightllm_ppl_int4_kernel import int4_weight_decode

@@ -71,9 +76,9 @@ def quantize(self, weight: torch.Tensor):
             from flash_llm_fp6_llm import weight_quant_to_fp6

             fp6_weight = weight_quant_to_fp6(quant_half, fp6_weight, True)
-        return fp6_weight.cuda(), scale.half().contiguous().cuda()
+        return fp6_weight.cuda(self.device_id_), scale.half().contiguous().cuda(self.device_id_)

-    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
        """ """
         from flash_llm_fp6_llm import linear_forward_cuda
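The new out-allocation branch is repeated in every backend that produces an output buffer; a condensed sketch of the shared pattern (`alloc_out` and the `cache_manager` parameter are invented here, with `g_cache_manager` passed in as the pooled allocator):

```python
import torch

def alloc_out(shape, dtype, device, use_custom_tensor_mananger=True, cache_manager=None):
    # Prefer the pooled allocator (CUDA-graph aware, reuses buffers) when the
    # caller allows it; otherwise fall back to a freshly allocated tensor.
    if use_custom_tensor_mananger and cache_manager is not None:
        return cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False)
    return torch.empty(shape, dtype=dtype, device=device)
```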
2 changes: 1 addition & 1 deletion lightllm/common/quantization/quantize_method.py
@@ -11,5 +11,5 @@ def quantize(self, weights: torch.Tensor):
         pass

     @abstractmethod
-    def apply(self, input_tensor, weight, bias=None, out=None):
+    def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True):
         pass
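With the widened abstract signature, every concrete method must accept the extra keyword even when it has no use for it. A minimal conforming subclass, as a sketch (`NoopQuant` is a made-up name, not part of this PR):

```python
import torch.nn.functional as F
from lightllm.common.quantization.quantize_method import QuantizationMethod

class NoopQuant(QuantizationMethod):
    # Hypothetical pass-through method illustrating the widened interface.
    def quantize(self, weights):
        return weights

    def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True):
        # The flag is accepted for interface compatibility and ignored here,
        # since this method allocates no intermediate buffers.
        return F.linear(input_tensor, weight, bias)
```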
6 changes: 4 additions & 2 deletions lightllm/common/quantization/torchao_quant.py
@@ -1,3 +1,4 @@
+import os
 import torch
 from .quantize_method import QuantizationMethod
 from .registry import QUANTMETHODS
@@ -32,15 +33,16 @@ def __init__(self):
         assert HAS_TORCH_AO, "torchao is not installed, you can't use quant api of it"
         assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4"
         self.quant_func = None
+        self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID"))

     def quantize(self, weight: torch.Tensor):
         """ """
         dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        dummy_linear.weight = torch.nn.Parameter(weight.cuda())
+        dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_))
         quantize_(dummy_linear, self.quant_func)
         return dummy_linear.weight

-    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
+    def apply(self, input_tensor, weights, bias=None, out=None, use_custom_tensor_mananger=True):
         return F.linear(input_tensor, weights, bias)
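The dummy-`Linear` indirection in `quantize` exists because torchao's `quantize_` rewrites module parameters in place rather than transforming raw tensors. The same trick outside the class, as a sketch (`int8_weight_only` is one available torchao config, chosen here only for illustration):

```python
import torch
from torchao.quantization import quantize_, int8_weight_only

def quantize_tensor(weight: torch.Tensor) -> torch.Tensor:
    # Wrap the raw weight in a throwaway Linear, let torchao convert the
    # module's parameter in place, then pull the quantized tensor back out.
    dummy = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
    dummy.weight = torch.nn.Parameter(weight)
    quantize_(dummy, int8_weight_only())
    return dummy.weight
```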
34 changes: 21 additions & 13 deletions lightllm/common/quantization/vllm_quant.py
@@ -1,3 +1,4 @@
+import os
 import torch
 from .quantize_method import QuantizationMethod
 from .registry import QUANTMETHODS
@@ -15,6 +16,7 @@ class vLLMBaseQuantizationMethod(QuantizationMethod):
     def __init__(self):
         super().__init__()
         assert HAS_VLLM, "vllm is not installed, you can't use quant api of it"
+        self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID"))

     def quantize(self, weight: torch.Tensor):
         """ """
@@ -32,14 +34,14 @@ def __init__(self):

     def quantize(self, weight: torch.Tensor):
         if isinstance(weight, tuple):
-            return (weight[0].transpose(0, 1).cuda(),) + weight[1:]
+            return (weight[0].transpose(0, 1).cuda(self.device_id_),) + weight[1:]
         weight = weight.float()
         scale = weight.abs().max(dim=-1)[0] / 127
         weight = weight.transpose(0, 1) / scale.reshape(1, -1)
         weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8)
-        return weight.cuda(), scale.cuda()
+        return weight.cuda(self.device_id_), scale.cuda(self.device_id_)

-    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
         input_scale = None
         if len(weights) == 3:
             qweight, weight_scale, input_scale = weights
@@ -52,9 +54,12 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
         m = input_tensor.shape[0]
         n = qweight.shape[1]
         if out is None:
-            out = g_cache_manager.alloc_tensor(
-                (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-            )
+            if use_custom_tensor_mananger:
+                out = g_cache_manager.alloc_tensor(
+                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
+                )
+            else:
+                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
         torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
         return out

@@ -69,31 +74,34 @@ def quantize(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
         qweight, weight_scale = ops.scaled_fp8_quant(
-            weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=True
+            weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
         )
         return qweight.transpose(0, 1), weight_scale

     def quantize_moe(self, weight):
         num_experts = weight.shape[0]
         qweights = []
         weight_scales = []
-        qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda()
+        qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
         for i in range(num_experts):
             qweight, weight_scale = ops.scaled_fp8_quant(
-                weight[i].contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
+                weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
             )
             qweights[i] = qweight
             weight_scales.append(weight_scale)
         weight_scale = torch.cat(weight_scales, dim=0).reshape(-1)
         return qweights, weight_scale

-    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
         x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
         if out is None:
-            out = g_cache_manager.alloc_tensor(
-                (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-            )
+            if use_custom_tensor_mananger:
+                out = g_cache_manager.alloc_tensor(
+                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
+                )
+            else:
+                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
         torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
         return out
25 changes: 25 additions & 0 deletions lightllm/models/vit/__init__.py
@@ -0,0 +1,25 @@
+import os
+import importlib.util
+from lightllm.utils.log_utils import init_logger
+from lightllm.models.internvl.img_process import load_image as default_load_image
+
+logger = init_logger(__name__)
+
+
+def get_load_image_func(weight_dir):
+    global load_image
+    pre_process_path = os.path.join(weight_dir, "pre_process.py")
+    if os.path.exists(pre_process_path):
+        logger.info(f"Found pre_process.py in {weight_dir}, attempting to load load_image from it.")
+        spec = importlib.util.spec_from_file_location("pre_process", pre_process_path)
+        pre_process = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(pre_process)
+        if hasattr(pre_process, "load_image"):
+            logger.info("load_image function replaced by the one in pre_process.py.")
+            return pre_process.load_image
+        else:
+            logger.info("load_image function not found in pre_process.py.")
+    else:
+        logger.info(f"pre_process.py not found in {weight_dir}, using default load_image.")
+
+    return default_load_image
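A user-supplied `pre_process.py` dropped into the weight directory only needs to expose a `load_image` callable. A minimal sketch of such a file, assuming the same call shape as internvl's default `load_image` (an image path in, pixel tensor out; the exact transform and signature here are illustrative, not part of the PR):

```python
# pre_process.py: placed in the model weight directory (hypothetical example)
import torchvision.transforms as T
from PIL import Image

def load_image(image_file, input_size=448):
    # Stand-in pipeline: resize, normalize, return a (1, 3, H, W) float tensor.
    transform = T.Compose([
        T.Resize((input_size, input_size)),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    image = Image.open(image_file).convert("RGB")
    return transform(image).unsqueeze(0)
```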
32 changes: 32 additions & 0 deletions lightllm/models/vit/infer_struct.py
@@ -0,0 +1,32 @@
+import torch
+import numpy as np
+from lightllm.common.basemodel import InferStateInfo
+from lightllm.common.req_manager import ReqManager
+
+
+class LlamaInferStateInfo(InferStateInfo):
+    def __init__(self):
+        super().__init__()
+        self.position_cos = None
+        self.position_sin = None
+        self.other_kv_index = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        if self.is_prefill:
+            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
+            b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy()
+            position_ids = torch.from_numpy(
+                np.concatenate(
+                    [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))],
+                    axis=0,
+                )
+            ).cuda()
+            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
+            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
+            position_ids = None
+        else:
+            position_ids = self.b_seq_len - 1
+            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
+            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
+        self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item()
+        return
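In the prefill branch, the `np.concatenate` builds one position id per not-yet-cached token of every request, resuming each request at its cached length. A small worked example (values illustrative):

```python
import numpy as np

# Two requests: 2 of 5 tokens already cached, and 0 of 3 cached.
b_ready_cache_len = np.array([2, 0])
b_seq_len = np.array([5, 3])

position_ids = np.concatenate(
    [np.arange(b_ready_cache_len[i], b_seq_len[i]) for i in range(len(b_seq_len))],
    axis=0,
)
print(position_ids)  # [2 3 4 0 1 2] (per-request positions, flattened)
```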
Empty file.