diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 4a3accc45..dcfbb247c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -1,3 +1,4 @@ +import os import torch from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_world_size, get_rank @@ -20,6 +21,7 @@ class BaseWeightTpl(BaseWeight): def __init__(self): self.world_size_ = get_world_size() self.tp_rank_ = get_rank() + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_)) def load_hf_weights(self, weights): pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py index cd0f1a5f9..b7605b845 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py @@ -1,3 +1,4 @@ +import os import torch from .base_weight import BaseWeight from lightllm.utils.dist_utils import get_world_size, get_rank @@ -21,6 +22,7 @@ def __init__( self.split_inter_size = split_inter_size self.data_type_ = data_type self.tp_rank_ = get_rank() + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_)) self.experts_up_projs = [None] * self.n_routed_experts self.experts_gate_projs = [None] * self.n_routed_experts self.w2_list = [None] * self.n_routed_experts @@ -113,10 +115,10 @@ def load_hf_weights(self, weights): self._fuse() def _cuda(self, cpu_tensor): - if self.tp_rank_ is None: + if self.device_id_ is None: return cpu_tensor.contiguous().to(self.data_type_).cuda() else: - return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_) + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.device_id_) def verify_load(self): return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py index c91dd8acb..a0f60823a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py @@ -28,7 +28,9 @@ def set_quant_method(self, quant_method): def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True): if self.quant_method is not None: - return self.quant_method.apply(input_tensor, self.weight, self.bias, out) + return self.quant_method.apply( + input_tensor, self.weight, self.bias, out, use_custom_tensor_mananger=use_custom_tensor_mananger + ) if out is None: shape = (input_tensor.shape[0], self.weight.shape[1]) dtype = input_tensor.dtype @@ -47,9 +49,9 @@ def _post_load_weights(self): if all(w is not None for w in [self.weight, self.weight_scale, self.input_scale]): self.weight = self.quant_method.quantize((self.weight, self.weight_scale, self.input_scale)) else: - self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.tp_rank_)) + self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_)) return - self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_) + self.weight = self.weight.transpose(0, 1).cuda(self.device_id_) class MMWeight(MMWeightTpl): @@ -84,7 +86,7 @@ def load_hf_weights(self, weights): self.weight = weight[self.start : self.end] if self.bias_name in weights: bias = 
weights[self.bias_name].to(self.data_type_)[self.start : self.end] - self.bias = bias.cuda(self.tp_rank_) + self.bias = bias.cuda(self.device_id_) if STATIC_QUANT and self.weight_scale_name in weights: weight_scale = weights[self.weight_scale_name].to(torch.float)[self.start : self.end] @@ -120,7 +122,7 @@ def load_hf_weights(self, weights): self.weight = weight[:, self.start : self.end] if self.bias_name in weights: bias = weights[self.bias_name] - self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.tp_rank_) + self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_) if STATIC_QUANT and self.weight_scale_name in weights: weight_scale = weights[self.weight_scale_name].to(torch.float) @@ -191,7 +193,7 @@ def _fuse(self): if self.has_bias: if self.bias is None and all(b is not None for b in self.biases): - self.bias = torch.cat(self.biases, dim=0).cuda(self.tp_rank_) + self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_) return self def load_hf_weights(self, weights): @@ -271,7 +273,7 @@ def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True): return torch.addbmm(self.bias, input_tensor, self.weight, out=out) def _post_load_weights(self): - self.weight = self.weight.cuda(self.tp_rank_) + self.weight = self.weight.cuda(self.device_id_) class BMMWeight(BMMWeightTpl): @@ -318,4 +320,4 @@ def __init__( super().__init__(weight_name, data_type, split_n_embed, bias_name) def _post_load_weights(self): - self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_) + self.weight = self.weight.transpose(0, 1).cuda(self.device_id_) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 029d723c9..5ba48516c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -12,9 +12,9 @@ def __init__(self, weight_name, data_type, bias_name=None): def load_hf_weights(self, weights): if self.weight_name in weights: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.tp_rank_) + self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.device_id_) if self.bias_name in weights: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.tp_rank_) + self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.device_id_) def verify_load(self): load_ok = True @@ -32,7 +32,7 @@ def __init__(self, weight_name, data_type, bias_name=None): def load_hf_weights(self, weights): if self.weight_name in weights: - self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.tp_rank_) + self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.device_id_) class TpNormWeight(NormWeight): @@ -41,10 +41,10 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): self.split_n_embed = split_n_embed def load_hf_weights(self, weights): - start = self.offset + self.split_n_embed * self.tp_rank_ - end = self.offset + self.split_n_embed * (self.tp_rank_ + 1) + start = self.split_n_embed * self.device_id_ + end = self.split_n_embed * (self.device_id_ + 1) if self.weight_name in weights: - self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.tp_rank_) + self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_) if self.bias_name in weights: - self.bias = 
weights[self.bias_name][start:end].to(self.data_type_).cuda(self.tp_rank_) + self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.device_id_) diff --git a/lightllm/common/quantization/ppl_quant.py b/lightllm/common/quantization/ppl_quant.py index cb1d6a4ea..e87f2b1e2 100644 --- a/lightllm/common/quantization/ppl_quant.py +++ b/lightllm/common/quantization/ppl_quant.py @@ -1,3 +1,4 @@ +import os import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS @@ -9,6 +10,7 @@ class PPLW4A16QuantizationMethod(QuantizationMethod): def __init__(self, group_size=128): super().__init__() self.group_size = group_size + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) def quantize(self, weight: torch.Tensor): """ @@ -17,13 +19,13 @@ def quantize(self, weight: torch.Tensor): qweight: [K, N//8] int32 (packed int4*8) new pack_order q_scale: [K//group_size, N] int32 """ - weight = weight.to(dtype=torch.float16).cuda() + weight = weight.to(dtype=torch.float16).cuda(self.device_id_) from lightllm_ppl_int4_kernel import int4_weight_encode qweight_new, q_scale = int4_weight_encode(weight, self.group_size) return qweight_new, q_scale - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): """ input_tensor is activation: (M, K) float16 weights: [qweight, scale_weight] @@ -38,7 +40,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): shape = (input_tensor.shape[0], qweight.shape[0] * 8) dtype = input_tensor.dtype device = input_tensor.device - out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + else: + out = torch.empty(shape, dtype=dtype, device=device) from lightllm_ppl_int4_kernel import matmul_i4_fp16 from lightllm_ppl_int4_kernel import int4_weight_decode @@ -71,9 +76,9 @@ def quantize(self, weight: torch.Tensor): from flash_llm_fp6_llm import weight_quant_to_fp6 fp6_weight = weight_quant_to_fp6(quant_half, fp6_weight, True) - return fp6_weight.cuda(), scale.half().contiguous().cuda() + return fp6_weight.cuda(self.device_id_), scale.half().contiguous().cuda(self.device_id_) - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): """ """ from flash_llm_fp6_llm import linear_forward_cuda diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index e50938f5a..e007e0e5f 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -11,5 +11,5 @@ def quantize(self, weights: torch.Tensor): pass @abstractmethod - def apply(self, input_tensor, weight, bias=None, out=None): + def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True): pass diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index e2e93d1ea..e35b99b44 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -1,3 +1,4 @@ +import os import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS @@ -32,15 +33,16 @@ def __init__(self): assert HAS_TORCH_AO, "torchao is 
not installed, you can't use quant api of it" assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" self.quant_func = None + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) def quantize(self, weight: torch.Tensor): """ """ dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - dummy_linear.weight = torch.nn.Parameter(weight.cuda()) + dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) quantize_(dummy_linear, self.quant_func) return dummy_linear.weight - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, use_custom_tensor_mananger=True): return F.linear(input_tensor, weights, bias) diff --git a/lightllm/common/quantization/vllm_quant.py b/lightllm/common/quantization/vllm_quant.py index fb24a3b8a..82b7790d1 100644 --- a/lightllm/common/quantization/vllm_quant.py +++ b/lightllm/common/quantization/vllm_quant.py @@ -1,3 +1,4 @@ +import os import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS @@ -15,6 +16,7 @@ class vLLMBaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() assert HAS_VLLM, "vllm is not installed, you can't use quant api of it" + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) def quantize(self, weight: torch.Tensor): """ """ @@ -32,14 +34,14 @@ def __init__(self): def quantize(self, weight: torch.Tensor): if isinstance(weight, tuple): - return (weight[0].transpose(0, 1).cuda(),) + weight[1:] + return (weight[0].transpose(0, 1).cuda(self.device_id_),) + weight[1:] weight = weight.float() scale = weight.abs().max(dim=-1)[0] / 127 weight = weight.transpose(0, 1) / scale.reshape(1, -1) weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) - return weight.cuda(), scale.cuda() + return weight.cuda(self.device_id_), scale.cuda(self.device_id_) - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): input_scale = None if len(weights) == 3: qweight, weight_scale, input_scale = weights @@ -52,9 +54,12 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): m = input_tensor.shape[0] n = qweight.shape[1] if out is None: - out = g_cache_manager.alloc_tensor( - (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False - ) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor( + (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False + ) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out @@ -69,7 +74,7 @@ def quantize(self, weight: torch.Tensor): if self.is_moe: return self.quantize_moe(weight) qweight, weight_scale = ops.scaled_fp8_quant( - weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=True + weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) return qweight.transpose(0, 1), weight_scale @@ -77,23 +82,26 @@ def quantize_moe(self, weight): num_experts = weight.shape[0] qweights = [] weight_scales = [] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda() + qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) for i in range(num_experts): qweight, weight_scale = ops.scaled_fp8_quant( - weight[i].contiguous().cuda(), 
scale=None, use_per_token_if_dynamic=False + weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False ) qweights[i] = qweight weight_scales.append(weight_scale) weight_scale = torch.cat(weight_scales, dim=0).reshape(-1) return qweights, weight_scale - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) m = input_tensor.shape[0] n = weights[0].shape[1] if out is None: - out = g_cache_manager.alloc_tensor( - (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False - ) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor( + (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False + ) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias) return out diff --git a/lightllm/models/vit/__init__.py b/lightllm/models/vit/__init__.py new file mode 100644 index 000000000..cbe40d195 --- /dev/null +++ b/lightllm/models/vit/__init__.py @@ -0,0 +1,25 @@ +import os +import importlib.util +from lightllm.utils.log_utils import init_logger +from lightllm.models.internvl.img_process import load_image as default_load_image + +logger = init_logger(__name__) + + +def get_load_image_func(weight_dir): + global load_image + pre_process_path = os.path.join(weight_dir, "pre_process.py") + if os.path.exists(pre_process_path): + logger.info(f"Found pre_process.py in {weight_dir}, attempting to load load_image from it.") + spec = importlib.util.spec_from_file_location("pre_process", pre_process_path) + pre_process = importlib.util.module_from_spec(spec) + spec.loader.exec_module(pre_process) + if hasattr(pre_process, "load_image"): + logger.info("load_image function replaced by the one in pre_process.py.") + return pre_process.load_image + else: + logger.info("load_image function not found in pre_process.py.") + else: + logger.info(f"pre_process.py not found in {weight_dir}, using default load_image.") + + return default_load_image diff --git a/lightllm/models/vit/infer_struct.py b/lightllm/models/vit/infer_struct.py new file mode 100644 index 000000000..35a8e68bc --- /dev/null +++ b/lightllm/models/vit/infer_struct.py @@ -0,0 +1,32 @@ +import torch +import numpy as np +from lightllm.common.basemodel import InferStateInfo +from lightllm.common.req_manager import ReqManager + + +class LlamaInferStateInfo(InferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.other_kv_index = None + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + if self.is_prefill: + b_seq_len_numpy = self.b_seq_len.cpu().numpy() + b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy() + position_ids = torch.from_numpy( + np.concatenate( + [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], + axis=0, + ) + ).cuda() + self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) + self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) + position_ids = None + else: + position_ids = self.b_seq_len - 1 + self.position_cos = torch.index_select(model._cos_cached, 0, 
position_ids).view(self.b_seq_len.shape[0], -1) + self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) + self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() + return diff --git a/lightllm/models/vit/layer_infer/__init__.py b/lightllm/models/vit/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/vit/layer_infer/post_layer_infer.py b/lightllm/models/vit/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..0f2ccddf7 --- /dev/null +++ b/lightllm/models/vit/layer_infer/post_layer_infer.py @@ -0,0 +1,59 @@ +import torch +import torch.functional as F +import torch.distributed as dist +from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight + + +class ViTPostLayerInfer: + """ """ + + def __init__(self, tp_rank, world_size, network_config, mode): + self.tp_rank_ = tp_rank + self.world_size_ = world_size + self.network_config_ = network_config + self.mode = mode + self.llm_hidden_size = network_config["llm_hidden_size"] + self.downsample_ratio = network_config["downsample_ratio"] + return + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight): + batch_size = vit_embeds.shape[0] + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds_norm = torch.nn.functional.layer_norm( + vit_embeds, + (vit_embeds.shape[-1],), + weight=layer_weight.layernorm_weight_, + bias=layer_weight.layernorm_bias_, + ) + + vit_embeds_1 = torch.addmm( + layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_ + ) + + vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1) + + vit_embeds_out = torch.addmm( + layer_weight.mlp1_3_bias_, + vit_embeds_gelu.view(-1, self.llm_hidden_size // self.world_size_), + layer_weight.mlp1_3_weight_, + beta=1.0 / self.world_size_, + ) + + if self.world_size_ == 1: + return vit_embeds_out.view(batch_size, -1, self.llm_hidden_size) + + dist.all_reduce(vit_embeds_out, op=dist.ReduceOp.SUM, async_op=False) + return vit_embeds_out.view(batch_size, -1, self.llm_hidden_size) diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..93de2fc0b --- /dev/null +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np + +from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight + + +class ViTPreLayerInfer: + """ """ + + def __init__(self, tp_rank, world_size, network_config, mode): + self.tp_rank_ = tp_rank + self.world_size_ = world_size + self.network_config_ = network_config + self.mode = mode + return + + def 
forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): + target_dtype = layer_weight.patch_embedding_weight_.dtype + patch_embeds = F.conv2d( + pixel_values, + weight=layer_weight.patch_embedding_weight_, + bias=layer_weight.patch_embedding_bias_, + stride=layer_weight.patch_size, + ) + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = layer_weight.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat( + [layer_weight.position_embedding[:, :1, :], layer_weight._get_pos_embed(height, width)], dim=1 + ) + embeddings = embeddings + position_embedding.to(target_dtype) + if self.world_size_ == 1: + return embeddings + gather_embedding = torch.empty( + (embeddings.shape[2] * self.world_size_, batch_size, embeddings.shape[1]), + device=embeddings.device, + dtype=target_dtype, + ) + split_indexes = np.linspace(0, layer_weight.embed_dim, self.world_size_ + 1, dtype=np.int64) + dist.all_gather( + [gather_embedding[split_indexes[i] : split_indexes[i + 1], :, :] for i in range(self.world_size_)], + embeddings.permute(2, 0, 1).contiguous(), + group=None, + async_op=False, + ) + return gather_embedding.permute(1, 2, 0).contiguous() diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..f13fcf428 --- /dev/null +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -0,0 +1,147 @@ +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np +from typing import Tuple +from functools import partial +import triton + +from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm +from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd + + +class ViTTransformerLayerInfer: + """ """ + + def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): + self.eps_ = network_config["layer_norm_eps"] + self.head_num = network_config["num_attention_heads"] + self.tp_padding_head_num = network_config["padding_head_num"] // world_size + self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] + self.embed_dim_ = network_config["hidden_size"] + self.qk_norm = network_config["qk_normalization"] + self.tp_padding_embed_dim_ = self.tp_padding_head_num * self.head_dim_ + self.tp_rank_ = tp_rank + self.world_size_ = world_size + self.network_config_ = network_config + self.mode = mode + self.layer_num_ = layer_num + return + + def norm(self, input, weight): + input_dtype = input.dtype + input_shape = input.shape + input = input.view(-1, self.tp_padding_head_num * self.head_dim_) + input = input.to(torch.float32) + variance = input.pow(2).mean(-1, keepdim=True) + input = input * torch.rsqrt(variance + self.eps_) + out = weight * input.to(input_dtype) + out = out.reshape(input_shape) + return out + + def tp_norm(self, input, weight): + input_shape = input.shape + input = input.view(-1, self.tp_padding_head_num * self.head_dim_) + input_dtype = input.dtype + input = input.to(torch.float32) + tp_variance = input.pow(2).sum(-1, keepdim=True) + if self.world_size_ > 1: + dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False) + variance = tp_variance / self.embed_dim_ + input = 
input * torch.rsqrt(variance + self.eps_) + out = weight * input.to(input_dtype) + out = out.reshape(input_shape) + return out + + def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + if layer_weight.norm_type == "rms_norm": + b = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_) + else: + b = torch.nn.functional.layer_norm( + input, + normalized_shape=[1024], + weight=layer_weight.att_norm_weight_.weight, + bias=layer_weight.att_norm_weight_.bias, + eps=layer_weight.layer_norm_eps, + ) + return b + + def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + if layer_weight.norm_type == "rms_norm": + return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_) + else: + return torch.nn.functional.layer_norm( + input, + normalized_shape=[1024], + weight=layer_weight.ffn_norm_weight_.weight, + bias=layer_weight.ffn_norm_weight_.bias, + eps=layer_weight.layer_norm_eps, + ) + + def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight) + k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight) + return q_norm, k_norm + + def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + batch_size = input.shape[0] + seq_len = input.shape[1] + qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False) + qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_) + q, k, v = qkv.unbind(2) + return q, k, v + + def _context_attention_kernel(self, q, k, v) -> torch.Tensor: + out = torch.empty_like(q) + batch_size = q.shape[0] + seq_len = q.shape[1] + flash_attention_fwd(q, k, v, out) + return out.reshape(batch_size, seq_len, -1) + + def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + batch_size = input.shape[0] + seq_len = input.shape[1] + o_tensor = layer_weight.o_proj.mm( + input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=False + ) + if layer_weight.use_ls: + o_tensor *= layer_weight.ls1 + return o_tensor.reshape((batch_size, seq_len, -1)) + + def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False) + ffn1_out = torch.nn.functional.gelu(fc1) + input_shape = input.shape + input = None + ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False) + if layer_weight.use_ls: + ffn2_out *= layer_weight.ls2 + ffn1_out = None + return ffn2_out.reshape(input_shape) + + def _context_attention(self, input_embding, layer_weight): + input1 = self._att_norm(input_embding, layer_weight) + q, k, v = self._get_qkv(input1, layer_weight) + if layer_weight.qk_norm: + q, k = self._qk_norm(q, k, layer_weight) + o = self._context_attention_kernel(q, k, v) + o = self._get_o(o, layer_weight) + if self.world_size_ > 1: + dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False) + input_embding.add_(o) + return + + def _context_ffn(self, input_embdings, layer_weight): + input1 = self._ffn_norm(input_embdings, layer_weight) + ffn_out = self._ffn(input1, layer_weight) + input1 = None + if self.world_size_ > 1: + dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False) + input_embdings.add_(ffn_out) + return + + def forward(self, input_embdings, layer_weight): + self._context_attention(input_embdings, layer_weight=layer_weight) + 
self._context_ffn(input_embdings, layer_weight) + return input_embdings diff --git a/lightllm/models/vit/layer_weights/__init__.py b/lightllm/models/vit/layer_weights/__init__.py new file mode 100644 index 000000000..08bcd04cf --- /dev/null +++ b/lightllm/models/vit/layer_weights/__init__.py @@ -0,0 +1,38 @@ +import os +import importlib.util + +# Default load_image function +def default_load_image(image_path): + print(f"Loading image using default function: {image_path}") + # Default image-loading logic (only an example here) + return image_path + + +# User-provided directory path +directory = "./user_directory" + +# Set the default load_image function to default_load_image +load_image = default_load_image + +# Check whether the directory contains a pre_process.py file +pre_process_path = os.path.join(directory, "pre_process.py") + +if os.path.exists(pre_process_path): + print(f"Found pre_process.py in {directory}, attempting to load load_image from it.") + + # Use importlib to load the module + spec = importlib.util.spec_from_file_location("pre_process", pre_process_path) + pre_process = importlib.util.module_from_spec(spec) + spec.loader.exec_module(pre_process) + + # If pre_process.py provides a load_image function, replace the default one + if hasattr(pre_process, "load_image"): + load_image = pre_process.load_image + print("load_image function replaced by the one in pre_process.py.") + else: + print("load_image function not found in pre_process.py.") +else: + print(f"pre_process.py not found in {directory}, using default load_image.") + +# Use the current load_image function +image = load_image("path/to/image.jpg") diff --git a/lightllm/models/vit/layer_weights/hf_load_utils.py b/lightllm/models/vit/layer_weights/hf_load_utils.py new file mode 100644 index 000000000..3fa82af8e --- /dev/null +++ b/lightllm/models/vit/layer_weights/hf_load_utils.py @@ -0,0 +1,68 @@ +import torch +import os +import gc +from safetensors import safe_open +import lightllm.utils.petrel_helper as utils + + +def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): + if use_safetensors: + weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + weights = {k: weights.get_tensor(k) for k in weights.keys()} + else: + weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") + new_weight = {} + for k, v in weights.items(): + if "language_model." 
in k: + new_weight[k[len("language_model.") :]] = v + else: + new_weight[k] = v + del weights + weights = new_weight + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weights) + del weights + gc.collect() + + +def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + if transformer_layer_list is not None: + assert transformer_layer_list[0].data_type_ == data_type, "type is not right" + if weight_dict: + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weight_dict) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weight_dict) + del weight_dict + return + use_safetensors = True + files = utils.PetrelHelper.list(weight_dir, extension="all") + candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) + if len(candidate_files) == 0: + use_safetensors = False + candidate_files = list(sorted(filter(lambda x: x.endswith(".bin"), files))) + candidate_files = candidate_files[0:45] + [candidate_files[-1]] + print(candidate_files) + assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." + from functools import partial + from multiprocessing.pool import ThreadPool as Pool + + partial_func = partial( + load_func, + use_safetensors=use_safetensors, + pre_post_layer=pre_post_layer, + transformer_layer_list=transformer_layer_list, + weight_dir=weight_dir, + ) # noqa + worker = int(os.environ.get("LOADWORKER", 1)) + with Pool(worker) as p: + _ = p.map(partial_func, candidate_files) + return diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..f56a884a6 --- /dev/null +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,89 @@ +import os +import torch +import numpy as np +import torch.nn.functional as F +from lightllm.common.basemodel import PreAndPostLayerWeight + + +class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, tp_rank, world_size, data_type, network_config, mode): + super().__init__(tp_rank, world_size, data_type, network_config, mode) + self.embed_dim = self.network_config_["hidden_size"] + self.image_size = self.network_config_["image_size"] + self.patch_size = self.network_config_["patch_size"] + self.llm_hidden_size = self.network_config_["llm_hidden_size"] + self.gpu_id_ = int(os.getenv("CURRENT_DEVICE_ID", tp_rank)) + return + + def _cuda(self, cpu_tensor): + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) + + def _get_pos_embed(self, H, W): + pos_embed = self.position_embedding[:, 1:, :] + target_dtype = pos_embed.dtype + pos_embed = ( + pos_embed.float() + .reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1) + .permute(0, 3, 1, 2) + ) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + .to(target_dtype) + ) + return pos_embed + + def load_hf_weights(self, weights): + split_indexes = np.linspace(0, self.embed_dim, self.world_size_ + 1, dtype=np.int64) + 
split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + if "vision_model.embeddings.class_embedding" in weights: + self.class_embedding = self._cuda( + weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end] + ) + if "vision_model.embeddings.position_embedding" in weights: + self.position_embedding = self._cuda( + weights["vision_model.embeddings.position_embedding"][:, :, split_start:split_end] + ) + if "vision_model.embeddings.patch_embedding.weight" in weights: + self.patch_embedding_weight_ = self._cuda( + weights["vision_model.embeddings.patch_embedding.weight"][split_start:split_end, :, :, :] + ) + if "vision_model.embeddings.patch_embedding.bias" in weights: + self.patch_embedding_bias_ = self._cuda( + weights["vision_model.embeddings.patch_embedding.bias"][split_start:split_end] + ) + + if "mlp1.0.weight" in weights: + self.layernorm_weight_ = self._cuda(weights["mlp1.0.weight"]) + if "mlp1.0.bias" in weights: + self.layernorm_bias_ = self._cuda(weights["mlp1.0.bias"]) + + split_indexes = np.linspace(0, self.llm_hidden_size, self.world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + + if "mlp1.1.weight" in weights: + self.mlp1_1_weight_ = self._cuda(weights["mlp1.1.weight"][split_start:split_end, :]).t() + if "mlp1.1.bias" in weights: + self.mlp1_1_bias_ = self._cuda(weights["mlp1.1.bias"][split_start:split_end]) + + if "mlp1.3.weight" in weights: + self.mlp1_3_weight_ = self._cuda(weights["mlp1.3.weight"][:, split_start:split_end]).t() + if "mlp1.3.bias" in weights: + self.mlp1_3_bias_ = self._cuda(weights["mlp1.3.bias"]) + + return + + def verify_load(self): + errors = "weights load not ok" + weights = [ + self.class_embedding, + self.position_embedding, + self.patch_embedding_weight_, + self.patch_embedding_bias_, + ] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + return diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..2676ff3ed --- /dev/null +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -0,0 +1,167 @@ +import os +import torch +import math +import numpy as np +import torch.nn.functional as F +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + NormWeight, + MultiROWMMWeight, + TpNormWeight, +) + + +class ViTTransformerLayerWeight(TransformerLayerWeight): + def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg) + self.gpu_id_ = int(os.getenv("CURRENT_DEVICE_ID", tp_rank)) + + return + + def _cuda(self, cpu_tensor): + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) + + def _parse_config(self): + self.padding_hidden_size = self.network_config_["padding_hidden_size"] + self.qk_norm = self.network_config_["qk_normalization"] + self.use_ls = self.network_config_.get("use_ls", False) + self.qkv_bias = self.network_config_.get("qkv_bias", True) + self.layer_norm_eps = self.network_config_.get("layer_norm_eps", 1e-6) + self.norm_type = self.network_config_.get("norm_type", "layer_norm") + + def _init_weight_names(self): + self._att_norm_weight_name = 
f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" + + self._q_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.q.weight" + self._k_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.k.weight" + self._v_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.v.weight" + + if self.qkv_bias: + self._q_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.q.bias" + self._k_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.k.bias" + self._v_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.v.bias" + else: + self._q_bias_name = None + self._k_bias_name = None + self._v_bias_name = None + + self._o_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight" + self._o_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias" + + self.fc1_weight_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight" + self.fc1_bias_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias" + self.fc2_weight_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight" + self.fc2_bias_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias" + + self._ls1_name = f"vision_model.encoder.layers.{self.layer_num_}.ls1" + self._ls2_name = f"vision_model.encoder.layers.{self.layer_num_}.ls2" + + self._att_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" + self._ffn_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight" + + if self.norm_type == "layer_norm": + self._att_norm_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias" + self._ffn_norm_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias" + else: + self._att_norm_bias_name = None + self._ffn_norm_bias_name = None + + if self.qk_norm: + self._q_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight" + self._k_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight" + self._q_norm_bias_name = None + self._k_norm_bias_name = None + + def _init_weight(self): + self._init_qkv() + self._init_o() + self._init_ffn() + self._init_norm() + + def _init_qkv(self): + n_embed = self.network_config_["hidden_size"] + qkv_split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + self.qkv_proj = MultiROWMMWeight( + [self._q_weight_name, self._k_weight_name, self._v_weight_name], + self.data_type_, + qkv_split_n_embed, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + ) + + def _init_o(self): + n_embed = self.network_config_["hidden_size"] + o_split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + self.o_proj = COLMMWeight(self._o_weight_name, self.data_type_, o_split_n_embed, bias_name=self._o_bias_name) + + def _init_ffn(self): + inter_size = self.network_config_["intermediate_size"] + split_inter_size = inter_size // self.world_size_ + self.ffn_1_proj_ = ROWMMWeight( + self.fc1_weight_name_, + self.data_type_, + split_inter_size, + bias_name=self.fc1_bias_name_, + ) + self.ffn_2_proj_ = COLMMWeight( + self.fc2_weight_name_, self.data_type_, split_inter_size, bias_name=self.fc2_bias_name_ + ) + + def _init_norm(self): + self.att_norm_weight_ = NormWeight( + self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + ) + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) + if self.qk_norm: + 
n_embed = self.network_config_["hidden_size"] + split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + self.q_norm_weight_ = TpNormWeight(self._q_norm_weight_name, self.data_type_, split_n_embed) + self.k_norm_weight_ = TpNormWeight(self._k_norm_weight_name, self.data_type_, split_n_embed) + + def load_hf_weights(self, weights): + if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: + n_embed = self.network_config_["hidden_size"] + att_qkv_dense_weight = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] + att_qkv_dense_weight = att_qkv_dense_weight.reshape(3, n_embed, -1) + q_weight_ = F.pad(att_qkv_dense_weight[0, :, :], (0, 0, 0, self.padding_hidden_size)) + k_weight_ = F.pad(att_qkv_dense_weight[1, :, :], (0, 0, 0, self.padding_hidden_size)) + v_weight_ = F.pad(att_qkv_dense_weight[2, :, :], (0, 0, 0, self.padding_hidden_size)) + del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] + weights[self._q_weight_name] = q_weight_ + weights[self._k_weight_name] = k_weight_ + weights[self._v_weight_name] = v_weight_ + + if self._o_weight_name in weights: + weights[self._o_weight_name] = F.pad(weights[self._o_weight_name], (0, self.padding_hidden_size, 0, 0)) + + if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias" in weights: + n_embed = self.network_config_["hidden_size"] + att_qkv_dense_bias = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] + att_qkv_dense_bias = F.pad(att_qkv_dense_bias, (0, self.padding_hidden_size)).reshape(3, -1) + q_bias_ = att_qkv_dense_bias[0] + k_bias_ = att_qkv_dense_bias[1] + v_bias_ = att_qkv_dense_bias[2] + weights[self._q_bias_name] = q_bias_ + weights[self._k_bias_name] = k_bias_ + weights[self._v_bias_name] = v_bias_ + del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] + + if self._q_norm_weight_name in weights: + weights[self._q_norm_weight_name] = F.pad(weights[self._q_norm_weight_name], (0, self.padding_hidden_size)) + + if self._k_norm_weight_name in weights: + weights[self._k_norm_weight_name] = F.pad(weights[self._k_norm_weight_name], (0, self.padding_hidden_size)) + + if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: + ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"] + self.ls1 = self._cuda(ls1) + + if f"vision_model.encoder.layers.{self.layer_num_}.ls2" in weights: + ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] + self.ls2 = self._cuda(ls2) + self.use_ls = True + + return super().load_hf_weights(weights) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py new file mode 100644 index 000000000..025809c95 --- /dev/null +++ b/lightllm/models/vit/model.py @@ -0,0 +1,174 @@ +import os +import json +import torch +from lightllm.models.vit.layer_infer.pre_layer_infer import ViTPreLayerInfer +from lightllm.models.vit.layer_infer.post_layer_infer import ViTPostLayerInfer +from lightllm.models.vit.layer_infer.transformer_layer_infer import ViTTransformerLayerInfer +from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight +from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight +from lightllm.models.vit.layer_weights.hf_load_utils import load_hf_weights +from lightllm.utils.log_utils import init_logger +from lightllm.models.vit import get_load_image_func +import torchvision.transforms as T +from lightllm.server.embed_cache.utils import read_shm, 
get_shm_name_data +from PIL import Image +from typing import List, Union +from io import BytesIO +from rpyc.utils.classic import obtain +from lightllm.common.quantization import Quantcfg + + +logger = init_logger(__name__) + + +class VisionTransformer: + + # weight class + pre_and_post_weight_class = ViTPreAndPostLayerWeight + transformer_weight_class = ViTTransformerLayerWeight + + # infer class + pre_layer_infer_class = ViTPreLayerInfer + transformer_layer_infer_class = ViTTransformerLayerInfer + post_layer_infer_class = ViTPostLayerInfer + + def __init__(self, kvargs): + self.tp_rank_ = kvargs["tp_rank"] + self.world_size_ = kvargs["world_size"] + self.weight_dir_ = kvargs["weight_dir"] + self.load_way = kvargs.get("load_way", "HF") + self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] + self.weight_dict = kvargs.get("weight_dict", None) + self.data_type = kvargs.get("data_type", "float16") + self.quant_type = kvargs.get("quant_type", None) + self.quant_cfg_path = kvargs.get("quant_cfg", None) + self.load_image_func = get_load_image_func(self.weight_dir_) + + self._init_datatype() + self._init_config() + self._padding_hidden_size() + self._init_quant() + self._init_weights() + self._init_infer_layer() + return + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + self.config = json.load(json_file) + self.select_layer = self.config["select_layer"] + self.config["vision_config"]["llm_hidden_size"] = self.config["llm_config"]["hidden_size"] + self.config["vision_config"]["downsample_ratio"] = self.config["downsample_ratio"] + self.config = self.config["vision_config"] + self.layers_num = self.config["num_hidden_layers"] + return + + def _padding_hidden_size(self): + self.config["padding_hidden_size"] = 0 + self.config["padding_head_num"] = self.config["num_attention_heads"] + + head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] + if self.config["num_attention_heads"] % self.world_size_ != 0: + padding_head_num = ( + self.config["num_attention_heads"] + self.world_size_ - 1 + ) // self.world_size_ * self.world_size_ - self.config["num_attention_heads"] + self.config["padding_hidden_size"] = padding_head_num * head_dim + self.config["padding_head_num"] += padding_head_num + return + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class( + self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.tp_rank_, + self.world_size_, + self.data_type, + network_config=self.config, + mode=self.mode, + quant_cfg=self.quant_cfg, + ) + for i in range(self.config["num_hidden_layers"]) + ] + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + return + + def _init_quant(self): + self.quant_cfg = Quantcfg(self.config["num_hidden_layers"], self.quant_type, self.quant_cfg_path) + logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class( + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + self.post_infer = self.post_layer_infer_class( + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + self.layers_infer = [ + self.transformer_layer_infer_class( + i, tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + for i in range(self.config["num_hidden_layers"]) + ] + return + + def _init_datatype(self): + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + + def forward(self, pixel_values): + input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) + for i in range(self.layers_num + self.select_layer + 1): + input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) + input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) + return input_embs + + @torch.no_grad() + def encode(self, image_uuids: List): + img_tensors = [] + valid_ids = [] + valid_id = 0 + uuids = [] + for i, url in enumerate(image_uuids): + if isinstance(url, int): + uuids.append(url) + image_data = read_shm(get_shm_name_data(url)) + image_data = Image.open(BytesIO(image_data)) + t = self.load_image_func(image_data) + img_tensors.append(t) + else: + raise Exception("Unsupport input types: {} for {}".format(type(url), url)) + + cur_num = img_tensors[-1].shape[0] + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + pixel_values = imgs.cuda().to(dtype=self.data_type) + print(pixel_values.shape, pixel_values.dtype) + all_img_embeds = self.forward(pixel_values) + return all_img_embeds, uuids, valid_ids + + def cuda(self): + return self + + def load_model(self, weight_dir): + pass diff --git a/lightllm/models/vit/triton_kernel/__init__.py b/lightllm/models/vit/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py new file mode 100644 index 000000000..062ff51f5 --- /dev/null +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -0,0 +1,191 @@ +import torch + +import triton +import triton.language as tl +import math +import torch.nn.functional as F + +TESLA = "Tesla" in torch.cuda.get_device_name(0) + + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + sm_scale, + seq_len, + Out, + q_stride_b, + q_stride_s, + q_stride_h, + q_stride_d, + k_stride_b, + k_stride_s, + k_stride_h, + k_stride_d, + v_stride_b, + v_stride_s, + v_stride_h, + v_stride_d, + o_stride_b, + o_stride_s, + o_stride_h, + o_stride_d, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(2) + cur_head = tl.program_id(1) + start_m = tl.program_id(0) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = cur_batch * q_stride_b + cur_head * 
q_stride_h + offs_m[:, None] * q_stride_s + offs_d[None, :] + q = tl.load(Q + off_q, mask=offs_m[:, None] < seq_len, other=0.0) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + off_k = ( + cur_batch * k_stride_b + + (start_n + offs_n[None, :]) * k_stride_s + + cur_head * k_stride_h + + offs_d[:, None] + ) + k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < seq_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk += tl.where((start_n + offs_n[None, :]) < seq_len, 0, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.maximum(tl.max(qk, 1), l_i) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + acc_scale = tl.exp(m_i - m_ij) + acc = acc * acc_scale[:, None] + + # update acc + off_v = ( + cur_batch * v_stride_b + + (start_n + offs_n[:, None]) * v_stride_s + + cur_head * v_stride_h + + offs_d[None, :] + ) + v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < seq_len, other=0.0) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + l_i_new = tl.exp(l_i - m_ij) + l_ij + l_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - l_i) + acc = acc * o_scale[:, None] + # initialize pointers to output + off_o = cur_batch * o_stride_b + offs_m[:, None] * o_stride_s + cur_head * o_stride_h + offs_d[None, :] + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < seq_len) + return + + @torch.no_grad() + def flash_attention_fwd( + q, + k, + v, + o, + ): + BLOCK = 64 + # shape constraints + batch_size, seq_len, head_num, head_dim = q.shape + + sm_scale = 1.0 / (head_dim ** 0.5) # softmax scale factor + # grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK)) # batch, head, + grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size) # batch, head, + num_warps = 4 + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + seq_len, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + BLOCK_M=BLOCK, + BLOCK_DMODEL=head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=2, + ) + return + +else: + raise Exception("error triton version!") + + +def torch_att(q, k, v): + head_dim = q.shape[-1] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + scale = head_dim ** -0.5 + attn = (q * scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + out = attn @ v + out = out.transpose(1, 2).contiguous() + return out + + +def test(): + import torch + import numpy as np + + B, L, H, D = 4, 1025, 7, 128 + dtype = torch.float16 + q = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + torch_out = torch_att(q, k, v) + import time + + torch.cuda.synchronize() + a = time.time() + for i in range(100): + flash_attention_fwd(q, k, v, o) + # o = torch_att(q, k, v) + torch.cuda.synchronize() + b = time.time() + # print(o.shape, torch_out.shape) + 
print((b - a) / 100 * 1000) + + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7145167e6..28e066425 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -197,7 +197,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") parser.add_argument( - "--visual_infer_batch_size", type=int, default=4, help="number of images to process in each inference batch" + "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch" ) parser.add_argument( "--visual_gpu_ids", nargs="+", type=int, default=[0], help="List of GPU IDs to use, e.g., 0 1 2" diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4bb1a1a2a..6a0adc6f9 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -77,6 +77,7 @@ def init_model(self, kvargs): self.pd_rpyc_port = kvargs.get("pd_rpyc_port", None) max_total_token_num = kvargs["max_total_token_num"] + os.environ["CURRENT_DEVICE_ID"] = str(self.tp_rank) torch.cuda.set_device(self.tp_rank) dist.init_process_group( diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index b03a27f6f..6ab3e70df 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -70,6 +70,8 @@ async def wait_to_model_ready(self): "data_type": self.args.data_type, "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], "visual_gpu_ids": self.args.visual_gpu_ids, + "quant_type": self.args.quant_type, + "quant_cfg": self.args.quant_cfg, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 08c9df17f..1bce4b926 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -11,6 +11,7 @@ from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel +from lightllm.models.vit.model import VisionTransformer from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed @@ -32,19 +33,18 @@ def exposed_init_model(self, kvargs): visual_nccl_port = kvargs["visual_nccl_port"] self.vit_rank_id = kvargs["vit_rank_id"] self.cache_client = rpyc.connect("localhost", self.cache_port) + self.data_type = kvargs["data_type"] torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) - if self.vit_tp != 1: - dist.init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{visual_nccl_port}", - rank=self.tp_rank_id, - world_size=self.vit_tp, - ) + dist.init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{visual_nccl_port}", + rank=self.tp_rank_id, + 
world_size=self.vit_tp, + ) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + os.environ["CURRENT_DEVICE_ID"] = str(visual_gpu_ids[self.vit_rank_id]) - if self.vit_tp != 1: - raise ValueError(f"ERROR: Not support vit_tp value: {self.vit_tp}") try: self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -54,7 +54,16 @@ def exposed_init_model(self, kvargs): elif self.model_type == "llava": self.model = LlavaVisionModel() elif self.model_type == "internvl_chat": - self.model = InternVLVisionModel() + kvargs = { + "tp_rank": self.tp_rank_id, + "world_size": self.vit_tp, + "weight_dir": weight_dir, + "data_type": self.data_type, + "quant_type": kvargs["quant_type"], + "quant_cfg": kvargs["quant_cfg"], + } + self.model = VisionTransformer(kvargs) + # self.model = InternVLVisionModel() else: raise Exception(f"can not support {self.model_type} now") diff --git a/test/model/model_infer_vit.py b/test/model/model_infer_vit.py new file mode 100644 index 000000000..65d0a7915 --- /dev/null +++ b/test/model/model_infer_vit.py @@ -0,0 +1,72 @@ +import numpy as np +from multiprocessing import Queue +import multiprocessing +import os +import time + +from lightllm.models.vit.model import VisionTransformer + + +def test_model_inference(world_size, weight_dir, quant_type=None): + workers = [] + for rank_id in range(world_size): + kvargs = { + "tp_rank": rank_id, + "world_size": world_size, + "weight_dir": weight_dir, + "data_type": "bf16", + "quant_type": quant_type, + "quant_cfg": None, + } + + proc = multiprocessing.Process(target=tppart_model_infer, args=(kvargs,)) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + return + + +def tppart_model_infer(model_kvargs): + import torch + from lightllm.distributed import set_custom_reduce + import torch.distributed as dist + + rank_id = model_kvargs["tp_rank"] + world_size = model_kvargs["world_size"] + + torch.cuda.set_device(rank_id) + os.environ["CURRENT_DEVICE_ID"] = str(rank_id) + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size) + set_custom_reduce() + dist.barrier() + torch.cuda.empty_cache() + model_part = VisionTransformer(model_kvargs) + test_data = torch.randn((13, 3, 448, 448)).cuda().to(torch.bfloat16) + # warm up + torch.cuda.synchronize() + for i in range(10): + model_part.forward(test_data) + torch.cuda.synchronize() + + torch.cuda.synchronize() + start_time = time.time() + for i in range(50): + model_part.forward(test_data) + torch.cuda.synchronize() + end_time = time.time() + + if rank_id == 0: + print("time total cost(ms):", (end_time - start_time) / 50 * 1000) + + return + + +if __name__ == "__main__": + import torch + + world_size = 2 + weight_dir = "your_multimodal_vit_path" + torch.multiprocessing.set_start_method("spawn") + test_model_inference(world_size, weight_dir, "vllm-w8a8")
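
Note on the device-selection pattern this diff introduces: weight tensors are now placed on the GPU named by the CURRENT_DEVICE_ID environment variable (set in base_backend.py and in the visual model_rpc.py), falling back to the tensor-parallel rank when it is unset. Below is a minimal standalone Python sketch of that pattern, not part of the diff; the helper names resolve_device_id and to_device are illustrative only.

import os
import torch

def resolve_device_id(tp_rank: int) -> int:
    # CURRENT_DEVICE_ID is exported by the server backends before model init;
    # fall back to the tensor-parallel rank when it is absent.
    return int(os.getenv("CURRENT_DEVICE_ID", tp_rank))

def to_device(cpu_tensor: torch.Tensor, tp_rank: int, dtype=torch.float16) -> torch.Tensor:
    # Mirrors the _cuda() helpers in the weight classes: cast, then place on the
    # resolved device instead of cuda(tp_rank).
    return cpu_tensor.contiguous().to(dtype).cuda(resolve_device_id(tp_rank))

if __name__ == "__main__":
    os.environ["CURRENT_DEVICE_ID"] = "1"  # e.g. a value taken from visual_gpu_ids
    w = to_device(torch.randn(4, 4), tp_rank=0)
    print(w.device)  # cuda:1

This is what keeps weight placement consistent when a visual worker pins a GPU id from visual_gpu_ids that differs from its tensor-parallel rank.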