From 1e6c7b85f8bf67b729244f6d756dfc769d6addd5 Mon Sep 17 00:00:00 2001
From: baishihao
Date: Mon, 9 Dec 2024 21:00:50 +0800
Subject: [PATCH 01/17] vit triton (draft)

---
 lightllm/models/vit/__init__.py               |   4 +
 lightllm/models/vit/infer_struct.py           |  32 +++
 lightllm/models/vit/layer_infer/__init__.py   |   0
 .../vit/layer_infer/post_layer_infer.py       |  59 +++++
 .../models/vit/layer_infer/pre_layer_infer.py |  48 ++++
 .../layer_infer/transformer_layer_infer.py    | 163 ++++++++++++++
 lightllm/models/vit/layer_weights/__init__.py |   0
 .../models/vit/layer_weights/hf_load_utils.py |  68 ++++++
 .../pre_and_post_layer_weight.py              |  88 ++++++++
 .../layer_weights/transformer_layer_weight.py | 211 ++++++++++++++++++
 lightllm/models/vit/model_vit.py              | 176 +++++++++++++++
 lightllm/models/vit/triton_kernel/__init__.py |   0
 .../vit/triton_kernel/flashattention_nopad.py | 179 +++++++++++++++
 .../visualserver/model_infer/model_rpc.py     |  24 +-
 14 files changed, 1045 insertions(+), 7 deletions(-)
 create mode 100644 lightllm/models/vit/__init__.py
 create mode 100644 lightllm/models/vit/infer_struct.py
 create mode 100644 lightllm/models/vit/layer_infer/__init__.py
 create mode 100644 lightllm/models/vit/layer_infer/post_layer_infer.py
 create mode 100644 lightllm/models/vit/layer_infer/pre_layer_infer.py
 create mode 100644 lightllm/models/vit/layer_infer/transformer_layer_infer.py
 create mode 100644 lightllm/models/vit/layer_weights/__init__.py
 create mode 100644 lightllm/models/vit/layer_weights/hf_load_utils.py
 create mode 100644 lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py
 create mode 100644 lightllm/models/vit/layer_weights/transformer_layer_weight.py
 create mode 100644 lightllm/models/vit/model_vit.py
 create mode 100644 lightllm/models/vit/triton_kernel/__init__.py
 create mode 100644 lightllm/models/vit/triton_kernel/flashattention_nopad.py

diff --git a/lightllm/models/vit/__init__.py b/lightllm/models/vit/__init__.py
new file mode 100644
index 000000000..8fac3743c
--- /dev/null
+++ b/lightllm/models/vit/__init__.py
@@ -0,0 +1,4 @@
+import os
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
diff --git a/lightllm/models/vit/infer_struct.py b/lightllm/models/vit/infer_struct.py
new file mode 100644
index 000000000..35a8e68bc
--- /dev/null
+++ b/lightllm/models/vit/infer_struct.py
@@ -0,0 +1,32 @@
+import torch
+import numpy as np
+from lightllm.common.basemodel import InferStateInfo
+from lightllm.common.req_manager import ReqManager
+
+
+class LlamaInferStateInfo(InferStateInfo):
+    def __init__(self):
+        super().__init__()
+        self.position_cos = None
+        self.position_sin = None
+        self.other_kv_index = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        if self.is_prefill:
+            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
+            b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy()
+            position_ids = torch.from_numpy(
+                np.concatenate(
+                    [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))],
+                    axis=0,
+                )
+            ).cuda()
+            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
+            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
+            position_ids = None
+        else:
+            position_ids = self.b_seq_len - 1
+            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
+            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
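+        # other_kv_index: the kv-cache slot of token position 0 of the first request
+        # in the batch (taken from the req-to-token index table).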
+ self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() + return diff --git a/lightllm/models/vit/layer_infer/__init__.py b/lightllm/models/vit/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/vit/layer_infer/post_layer_infer.py b/lightllm/models/vit/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..0f2ccddf7 --- /dev/null +++ b/lightllm/models/vit/layer_infer/post_layer_infer.py @@ -0,0 +1,59 @@ +import torch +import torch.functional as F +import torch.distributed as dist +from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight + + +class ViTPostLayerInfer: + """ """ + + def __init__(self, tp_rank, world_size, network_config, mode): + self.tp_rank_ = tp_rank + self.world_size_ = world_size + self.network_config_ = network_config + self.mode = mode + self.llm_hidden_size = network_config["llm_hidden_size"] + self.downsample_ratio = network_config["downsample_ratio"] + return + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight): + batch_size = vit_embeds.shape[0] + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds_norm = torch.nn.functional.layer_norm( + vit_embeds, + (vit_embeds.shape[-1],), + weight=layer_weight.layernorm_weight_, + bias=layer_weight.layernorm_bias_, + ) + + vit_embeds_1 = torch.addmm( + layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_ + ) + + vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1) + + vit_embeds_out = torch.addmm( + layer_weight.mlp1_3_bias_, + vit_embeds_gelu.view(-1, self.llm_hidden_size // self.world_size_), + layer_weight.mlp1_3_weight_, + beta=1.0 / self.world_size_, + ) + + if self.world_size_ == 1: + return vit_embeds_out.view(batch_size, -1, self.llm_hidden_size) + + dist.all_reduce(vit_embeds_out, op=dist.ReduceOp.SUM, async_op=False) + return vit_embeds_out.view(batch_size, -1, self.llm_hidden_size) diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..c9134e3c2 --- /dev/null +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np + +from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight +from lightllm.common.basemodel import PreLayerInferTpl + + +class ViTPreLayerInfer(PreLayerInferTpl): + """ """ + + def __init__(self, tp_rank, world_size, network_config, mode): + super().__init__(tp_rank, world_size, network_config, mode) + return + + def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): + target_dtype = layer_weight.patch_embedding_weight_.dtype + 
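+        # Patch embedding: the strided Conv2d below splits the image into
+        # patch_size x patch_size patches and projects each one to the ViT hidden size.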
patch_embeds = F.conv2d( + pixel_values, + weight=layer_weight.patch_embedding_weight_, + bias=layer_weight.patch_embedding_bias_, + stride=layer_weight.patch_size, + ) + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = layer_weight.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat( + [layer_weight.position_embedding[:, :1, :], layer_weight._get_pos_embed(height, width)], dim=1 + ) + embeddings = embeddings + position_embedding.to(target_dtype) + if self.world_size_ == 1: + return embeddings + gather_embedding = torch.empty( + (embeddings.shape[2] * self.world_size_, batch_size, embeddings.shape[1]), + device=embeddings.device, + dtype=target_dtype, + ) + split_indexes = np.linspace(0, layer_weight.embed_dim, self.world_size_ + 1, dtype=np.int64) + dist.all_gather( + [gather_embedding[split_indexes[i] : split_indexes[i + 1], :, :] for i in range(self.world_size_)], + embeddings.permute(2, 0, 1).contiguous(), + group=None, + async_op=False, + ) + return gather_embedding.permute(1, 2, 0).contiguous() diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..5ffcbd841 --- /dev/null +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -0,0 +1,163 @@ +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np +from typing import Tuple +from functools import partial +import triton + +from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm +from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd + + +class ViTTransformerLayerInfer: + """ """ + + def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): + self.eps_ = network_config["layer_norm_eps"] + self.head_num = network_config["num_attention_heads"] + self.tp_padding_head_num = network_config["padding_head_num"] // world_size + self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] + self.embed_dim_ = network_config["hidden_size"] + self.qk_norm = network_config["qk_normalization"] + self.tp_padding_embed_dim_ = self.tp_padding_head_num * self.head_dim_ + self.tp_rank_ = tp_rank + self.world_size_ = world_size + self.network_config_ = network_config + self.mode = mode + self.layer_num_ = layer_num + return + + def norm(self, input, weight): + input_dtype = input.dtype + input = input.to(torch.float32) + variance = input.pow(2).mean(-1, keepdim=True) + input = input * torch.rsqrt(variance + self.eps_) + return weight * input.to(input_dtype) + + def tp_norm(self, input, weight): + input_dtype = input.dtype + input = input.to(torch.float32) + tp_variance = input.pow(2).sum(-1, keepdim=True) + dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False) + variance = tp_variance / self.embed_dim_ + input = input * torch.rsqrt(variance + self.eps_) + return weight * input.to(input_dtype) + + def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + if layer_weight.norm_type == "rms_norm": + b = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_, eps=self.eps_) + else: + b = torch.nn.functional.layer_norm( + input, + normalized_shape=[1024], + 
weight=layer_weight.att_norm_weight_, + bias=layer_weight.att_norm_bias_, + eps=layer_weight.layer_norm_eps, + ) + # b = torch.empty_like(input) + # rms_norm(b, input, layer_weight.att_norm_weight_, self.eps_ , 4) + return b + + def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + # return self.norm(input, layer_weight.ffn_norm_weight_) + if layer_weight.norm_type == "rms_norm": + return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_, eps=self.eps_) + else: + return torch.nn.functional.layer_norm( + input, + normalized_shape=[1024], + weight=layer_weight.ffn_norm_weight_, + bias=layer_weight.ffn_norm_bias_, + eps=layer_weight.layer_norm_eps, + ) + + def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + q_norm = self.tp_norm(q, layer_weight.q_norm_weight_) + k_norm = self.tp_norm(k, layer_weight.k_norm_weight_) + # import numpy as np + # q_norm = rmsnorm_forward(q, weight=layer_weight.q_norm_weight_, eps=self.eps_) + # k_norm = rmsnorm_forward(k, weight=layer_weight.k_norm_weight_, eps=self.eps_) + return q_norm, k_norm + + def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + batch_size = input.shape[0] + seq_len = input.shape[1] + if layer_weight.qkv_bias: + qkv = torch.addmm(layer_weight.qkv_bias_, input.view(-1, self.embed_dim_), layer_weight.qkv_weight_).view( + batch_size, seq_len, 3, -1, self.head_dim_ + ) + q, k, v = qkv.unbind(2) + else: + q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_).view(batch_size, seq_len, -1) + k = torch.mm(input.view(-1, self.embed_dim_), layer_weight.k_weight_).view(batch_size, seq_len, -1) + v = torch.mm(input.view(-1, self.embed_dim_), layer_weight.v_weight_).view(batch_size, seq_len, -1) + return q, k, v + + def _context_attention_kernel(self, q, k, v) -> torch.Tensor: + out = torch.empty_like(q) + batch_size = q.shape[0] + seq_len = q.shape[1] + # import time + # torch.cuda.synchronize() + # a = time.time() + # for i in range(100): + flash_attention_fwd(q, k, v, out) + # torch.cuda.synchronize() + # b = time.time() + # print(f"{self.layer_num_} The time is {(b - a) * 10}") + return out.reshape(batch_size, seq_len, -1) + + def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + batch_size = input.shape[0] + seq_len = input.shape[1] + o_tensor = torch.addmm( + layer_weight.o_bias_, + input.view(-1, self.tp_padding_head_num * self.head_dim_), + layer_weight.o_weight_, + ) + return o_tensor.reshape((batch_size, seq_len, -1)) + + def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: + fc1 = torch.addmm(layer_weight.fc1_bias_, input.view(-1, self.embed_dim_), layer_weight.fc1_weight_) + ffn1_out = torch.nn.functional.gelu(fc1) + input_shape = input.shape + input = None + ffn2_out = torch.addmm(layer_weight.fc2_bias_, ffn1_out, layer_weight.fc2_weight_) + ffn1_out = None + return ffn2_out.reshape(input_shape) + + def _context_attention(self, input_embding, layer_weight): + input1 = self._att_norm(input_embding, layer_weight) + # import time + # torch.cuda.synchronize() + # a = time.time() + q, k, v = self._get_qkv(input1, layer_weight) + # if self.qk_norm: + # q, k = self._qk_norm(q, k, layer_weight) + # input1 = None + o = self._context_attention_kernel(q, k, v) + # q = None + o = self._get_o(o, layer_weight) + if self.world_size_ > 1: + dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False) + input_embding.add_(o) + # torch.cuda.synchronize() + # b = time.time() + # 
print(f"{self.layer_num_} The time is {(b - a) * 1000}", layer_weight.o_weight_.shape) + return + + def _context_ffn(self, input_embdings, layer_weight): + input1 = self._ffn_norm(input_embdings, layer_weight) + ffn_out = self._ffn(input1, layer_weight) + input1 = None + if self.world_size_ > 1: + dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False) + input_embdings.add_(ffn_out) + return + + def forward(self, input_embdings, layer_weight): + self._context_attention(input_embdings, layer_weight=layer_weight) + self._context_ffn(input_embdings, layer_weight) + return input_embdings diff --git a/lightllm/models/vit/layer_weights/__init__.py b/lightllm/models/vit/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/vit/layer_weights/hf_load_utils.py b/lightllm/models/vit/layer_weights/hf_load_utils.py new file mode 100644 index 000000000..3fa82af8e --- /dev/null +++ b/lightllm/models/vit/layer_weights/hf_load_utils.py @@ -0,0 +1,68 @@ +import torch +import os +import gc +from safetensors import safe_open +import lightllm.utils.petrel_helper as utils + + +def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None): + if use_safetensors: + weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + weights = {k: weights.get_tensor(k) for k in weights.keys()} + else: + weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu") + new_weight = {} + for k, v in weights.items(): + if "language_model." in k: + new_weight[k[len("language_model.") :]] = v + else: + new_weight[k] = v + del weights + weights = new_weight + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weights) + del weights + gc.collect() + + +def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + if transformer_layer_list is not None: + assert transformer_layer_list[0].data_type_ == data_type, "type is not right" + if weight_dict: + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weight_dict) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weight_dict) + del weight_dict + return + use_safetensors = True + files = utils.PetrelHelper.list(weight_dir, extension="all") + candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files)) + if len(candidate_files) == 0: + use_safetensors = False + candidate_files = list(sorted(filter(lambda x: x.endswith(".bin"), files))) + candidate_files = candidate_files[0:45] + [candidate_files[-1]] + print(candidate_files) + assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights." 
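+    # Load the candidate weight files concurrently with a thread pool; the
+    # LOADWORKER environment variable sets the number of loader threads (defaults to 1).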
+ from functools import partial + from multiprocessing.pool import ThreadPool as Pool + + partial_func = partial( + load_func, + use_safetensors=use_safetensors, + pre_post_layer=pre_post_layer, + transformer_layer_list=transformer_layer_list, + weight_dir=weight_dir, + ) # noqa + worker = int(os.environ.get("LOADWORKER", 1)) + with Pool(worker) as p: + _ = p.map(partial_func, candidate_files) + return diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..eb0cffddf --- /dev/null +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,88 @@ +import torch +import numpy as np +import torch.nn.functional as F +from lightllm.common.basemodel import PreAndPostLayerWeight + + +class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, tp_rank, gpud_id, world_size, data_type, network_config, mode): + super().__init__(tp_rank, world_size, data_type, network_config, mode) + self.embed_dim = self.network_config_["hidden_size"] + self.image_size = self.network_config_["image_size"] + self.patch_size = self.network_config_["patch_size"] + self.llm_hidden_size = self.network_config_["llm_hidden_size"] + self.gpu_id_ = gpud_id + return + + def _cuda(self, cpu_tensor): + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) + + def _get_pos_embed(self, H, W): + pos_embed = self.position_embedding[:, 1:, :] + target_dtype = pos_embed.dtype + pos_embed = ( + pos_embed.float() + .reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1) + .permute(0, 3, 1, 2) + ) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + .to(target_dtype) + ) + return pos_embed + + def load_hf_weights(self, weights): + split_indexes = np.linspace(0, self.embed_dim, self.world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + if "vision_model.embeddings.class_embedding" in weights: + self.class_embedding = self._cuda( + weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end] + ) + if "vision_model.embeddings.position_embedding" in weights: + self.position_embedding = self._cuda( + weights["vision_model.embeddings.position_embedding"][:, :, split_start:split_end] + ) + if "vision_model.embeddings.patch_embedding.weight" in weights: + self.patch_embedding_weight_ = self._cuda( + weights["vision_model.embeddings.patch_embedding.weight"][split_start:split_end, :, :, :] + ) + if "vision_model.embeddings.patch_embedding.bias" in weights: + self.patch_embedding_bias_ = self._cuda( + weights["vision_model.embeddings.patch_embedding.bias"][split_start:split_end] + ) + + if "mlp1.0.weight" in weights: + self.layernorm_weight_ = self._cuda(weights["mlp1.0.weight"]) + if "mlp1.0.bias" in weights: + self.layernorm_bias_ = self._cuda(weights["mlp1.0.bias"]) + + split_indexes = np.linspace(0, self.llm_hidden_size, self.world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + + if "mlp1.1.weight" in weights: + self.mlp1_1_weight_ = self._cuda(weights["mlp1.1.weight"][split_start:split_end, :]).t() + if "mlp1.1.bias" in weights: + self.mlp1_1_bias_ = self._cuda(weights["mlp1.1.bias"][split_start:split_end]) + + if "mlp1.3.weight" in weights: + self.mlp1_3_weight_ = 
self._cuda(weights["mlp1.3.weight"][:, split_start:split_end]).t() + if "mlp1.3.bias" in weights: + self.mlp1_3_bias_ = self._cuda(weights["mlp1.3.bias"]) + + return + + def verify_load(self): + errors = "weights load not ok" + weights = [ + self.class_embedding, + self.position_embedding, + self.patch_embedding_weight_, + self.patch_embedding_bias_, + ] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + return diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..fa4737a98 --- /dev/null +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -0,0 +1,211 @@ +import torch +import math +import numpy as np +import torch.nn.functional as F +from lightllm.common.basemodel import TransformerLayerWeight + + +class ViTTransformerLayerWeight(TransformerLayerWeight): + def __init__(self, layer_num, tp_rank, gpu_id, world_size, data_type, network_config, mode=[], quant_cfg=None): + self.padding_hidden_size = network_config["padding_hidden_size"] + self.qk_norm = network_config["qk_normalization"] + self.use_ls = network_config.get("use_ls", False) + self.qkv_bias = network_config.get("qkv_bias", True) + self.layer_norm_eps = network_config.get("layer_norm_eps", 1e-6) + self.norm_type = network_config.get("norm_type", "layer_norm") + self.gpu_id_ = gpu_id + super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg) + return + + def _cuda(self, cpu_tensor): + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) + + def _try_cat_to(self, source_tensor_names, dest_name, cat_dim, handle_func=None): + if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name): + with self.lock: + if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name): + assert all( + not getattr(self, name, None).is_cuda for name in source_tensor_names + ), "all not cuda tensor" + tensors = [getattr(self, name, None) for name in source_tensor_names] + ans = torch.cat(tensors, dim=cat_dim) + if handle_func is not None: + ans = handle_func(ans) + else: + ans = self._cuda(ans) + setattr(self, dest_name, ans) + for name in source_tensor_names: + delattr(self, name) + return + + def load_hf_weights(self, weights): + self._load_qkvo_weights(weights) + self._load_ffn_weights(weights) + return + + def post_load(self): + # merge ls + ls1 = self.ls1.to(torch.float64) + self.o_bias_ = (self.o_bias_.to(torch.float64) * ls1).to(self.data_type_) + self.o_weight_ = (self.o_weight_.to(torch.float64) * ls1.reshape(1, -1)).to(self.data_type_) + + ls2 = self.ls2.to(torch.float64) + self.fc2_bias_ = (self.fc2_bias_.to(torch.float64) * ls2).to(self.data_type_) + self.fc2_weight_ = (self.fc2_weight_.to(torch.float64) * ls2.reshape(1, -1)).to(self.data_type_) + del self.ls1 + del self.ls2 + torch.cuda.empty_cache() + + def _transpose(self, ans): + return ans.t().cuda(self.gpu_id_) + + def verify_load(self): + errors = "weights load not ok" + if not self.qk_norm: + self.q_norm_weight_ = torch.ones(1).cuda(self.gpu_id_) + self.k_norm_weight_ = torch.ones(1).cuda(self.gpu_id_) + if not self.use_ls: + self.ls1 = 1.0 + self.ls2 = 1.0 + + weights = [ + self.att_norm_weight_, + self.q_norm_weight_, + self.k_norm_weight_, + # self.q_weight_, + self.o_weight_, + self.o_bias_, + self.ffn_norm_weight_, + self.fc1_weight_, + self.fc1_bias_, + 
self.fc2_weight_, + self.fc2_bias_, + self.ls1, + self.ls2, + ] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + self.post_load() + return + + def _load_qkvo_weights(self, weights): + # input layernorm params + if f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" in weights: + self.att_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight"]) + if f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias" in weights: + self.att_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias"]) + + n_embed = self.network_config_["hidden_size"] + split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + if f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight" in weights: + q_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight"] + q_norm_weight_ = F.pad(q_norm_weight_, (0, self.padding_hidden_size)) + self.q_norm_weight_ = self._cuda( + q_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + ) + + if f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight" in weights: + k_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight"] + k_norm_weight_ = F.pad(k_norm_weight_, (0, self.padding_hidden_size)) + self.k_norm_weight_ = self._cuda( + k_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + ) + + # q k v weights for llama + if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: + att_qkv_dense_weight = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] + + att_qkv_dense_weight = att_qkv_dense_weight.reshape(3, n_embed, -1) + # self.qkv_weight_ = self._cuda(att_qkv_dense_weight).t() + + q_weight_ = F.pad(att_qkv_dense_weight[0, :, :], (0, 0, 0, self.padding_hidden_size)) + self.q_weight_ = q_weight_.reshape(-1, n_embed)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ] + + k_weight_ = F.pad(att_qkv_dense_weight[1, :, :], (0, 0, 0, self.padding_hidden_size)) + self.k_weight_ = k_weight_.reshape(-1, n_embed)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ] + + v_weight_ = F.pad(att_qkv_dense_weight[2, :, :], (0, 0, 0, self.padding_hidden_size)) + self.v_weight_ = v_weight_.reshape(-1, n_embed)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ] + + if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias" in weights: + att_qkv_dense_bias = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] + # self.qkv_bias_ = self._cuda(att_qkv_dense_bias) + self.q_bias_ = att_qkv_dense_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + self.k_bias_ = att_qkv_dense_bias[ + n_embed + split_n_embed * self.tp_rank_ : n_embed + split_n_embed * (self.tp_rank_ + 1) + ] + self.v_bias_ = att_qkv_dense_bias[ + n_embed * 2 + split_n_embed * self.tp_rank_ : n_embed * 2 + split_n_embed * (self.tp_rank_ + 1) + ] + + self._try_cat_to(["q_weight_", "k_weight_", "v_weight_"], "qkv_weight_", cat_dim=0, handle_func=self._transpose) + self._try_cat_to(["q_bias_", "k_bias_", "v_bias_"], "qkv_bias_", cat_dim=0) + # attention output dense params + if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight" in weights: + o_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight"] + o_weight_ = F.pad(o_weight_, (0, 
self.padding_hidden_size, 0, 0)) + o_weight_ = o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + # print(o_weight_.shape, o_weight_) + self.o_weight_ = self._cuda(o_weight_).t() + if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias" in weights: + o_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias"] + if self.tp_rank_ == 0: + self.o_bias_ = self._cuda(o_bias_) + else: + self.o_bias_ = self._cuda(torch.zeros_like(o_bias_)) + + if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: + ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"] + self.ls1 = self._cuda(ls1) + self.use_ls = True + + # self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + + return + + def _load_ffn_weights(self, weights): + if f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight" in weights: + self.ffn_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight"]) + + if f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias" in weights: + self.ffn_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias"]) + + inter_size = self.network_config_["intermediate_size"] + split_inter_size = inter_size // self.world_size_ + + if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight" in weights: + fc1_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.fc1_weight_ = self._cuda(fc1_weight_).t() + + if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias" in weights: + fc1_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] + self.fc1_bias_ = self._cuda(fc1_bias_) + + if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight" in weights: + fc2_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] + self.fc2_weight_ = self._cuda(fc2_weight_).t() + + if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias" in weights: + fc2_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias"] + if self.tp_rank_ == 0: + self.fc2_bias_ = self._cuda(fc2_bias_) + else: + self.fc2_bias_ = self._cuda(torch.zeros_like(fc2_bias_)) + + if f"vision_model.encoder.layers.{self.layer_num_}.ls2" in weights: + ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] + self.ls2 = self._cuda(ls2) + + return diff --git a/lightllm/models/vit/model_vit.py b/lightllm/models/vit/model_vit.py new file mode 100644 index 000000000..a616856d7 --- /dev/null +++ b/lightllm/models/vit/model_vit.py @@ -0,0 +1,176 @@ +import os +import json +import torch +from lightllm.models.vit.layer_infer.pre_layer_infer import ViTPreLayerInfer +from lightllm.models.vit.layer_infer.post_layer_infer import ViTPostLayerInfer +from lightllm.models.vit.layer_infer.transformer_layer_infer import ViTTransformerLayerInfer +from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight +from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight +from lightllm.models.vit.layer_weights.hf_load_utils import load_hf_weights +from lightllm.utils.log_utils import init_logger +from lightllm.models.internvl.img_process import 
load_image +import torchvision.transforms as T +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from PIL import Image +from typing import List, Union +from io import BytesIO +from rpyc.utils.classic import obtain +from lightllm.common.quantization import Quantcfg + + +logger = init_logger(__name__) + + +class VisionTransformer: + + # weight class + pre_and_post_weight_class = ViTPreAndPostLayerWeight + transformer_weight_class = ViTTransformerLayerWeight + + # infer class + pre_layer_infer_class = ViTPreLayerInfer + transformer_layer_infer_class = ViTTransformerLayerInfer + post_layer_infer_class = ViTPostLayerInfer + + def __init__(self, kvargs): + self.tp_rank_ = kvargs["tp_rank"] + self.gpu_id_ = kvargs["gpu_id"] + self.world_size_ = kvargs["world_size"] + self.weight_dir_ = kvargs["weight_dir"] + self.load_way = kvargs.get("load_way", "HF") + self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] + self.weight_dict = kvargs.get("weight_dict", None) + self.data_type = kvargs.get("data_type", "float16") + self.quant_type = kvargs.get("quant_type", None) + self.quant_cfg_path = kvargs.get("quant_cfg", None) + + self._init_datatype() + self._init_config() + self._padding_hidden_size() + self._init_quant() + self._init_weights() + self._init_infer_layer() + return + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + self.config = json.load(json_file) + # self.select_layer = self.config["vision_config"]["vit_select_layer"] + self.select_layer = self.config["select_layer"] + self.config["vision_config"]["llm_hidden_size"] = self.config["llm_config"]["hidden_size"] + self.config["vision_config"]["downsample_ratio"] = self.config["downsample_ratio"] + self.config = self.config["vision_config"] + self.layers_num = self.config["num_hidden_layers"] + return + + def _padding_hidden_size(self): + self.config["padding_hidden_size"] = 0 + self.config["padding_head_num"] = self.config["num_attention_heads"] + + head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] + if self.config["num_attention_heads"] % self.world_size_ != 0: + padding_head_num = ( + self.config["num_attention_heads"] + self.world_size_ - 1 + ) // self.world_size_ * self.world_size_ - self.config["num_attention_heads"] + self.config["padding_hidden_size"] = padding_head_num * head_dim + self.config["padding_head_num"] += padding_head_num + return + + def _init_weights(self): + self.pre_post_weight = self.pre_and_post_weight_class( + self.tp_rank_, self.gpu_id_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode + ) + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.tp_rank_, + self.gpu_id_, + self.world_size_, + self.data_type, + network_config=self.config, + mode=self.mode, + quant_cfg=self.quant_cfg, + ) + for i in range(self.config["num_hidden_layers"]) + ] + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + return + + def _init_quant(self): + self.quant_cfg = Quantcfg(self.config["num_hidden_layers"], self.quant_type, self.quant_cfg_path) + logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class( + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + self.post_infer = self.post_layer_infer_class( + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + self.layers_infer = [ + self.transformer_layer_infer_class( + i, tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + for i in range(self.config["num_hidden_layers"]) + ] + return + + def _init_datatype(self): + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + + # @torch.no_grad() + def forward(self, pixel_values): + input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) + for i in range(self.layers_num + self.select_layer + 1): + input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) + input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) + return input_embs + + @torch.no_grad() + def encode(self, image_uuids: List): + img_tensors = [] + valid_ids = [] + valid_id = 0 + uuids = [] + for i, url in enumerate(image_uuids): + if isinstance(url, int): + uuids.append(url) + image_data = read_shm(get_shm_name_data(url)) + image_data = Image.open(BytesIO(image_data)) + t = load_image(image_data) + img_tensors.append(t) + else: + raise Exception("Unsupport input types: {} for {}".format(type(url), url)) + + cur_num = img_tensors[-1].shape[0] + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + pixel_values = imgs.cuda().to(dtype=self.data_type) + all_img_embeds = self.forward(pixel_values) + return all_img_embeds, uuids, valid_ids + + def cuda(self): + return self + + def load_model(self, weight_dir): + pass diff --git a/lightllm/models/vit/triton_kernel/__init__.py b/lightllm/models/vit/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py new file mode 100644 index 000000000..4a05ff635 --- /dev/null +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -0,0 +1,179 @@ +import torch + +import triton +import triton.language as tl +import math +import torch.nn.functional as F + +TESLA = "Tesla" in torch.cuda.get_device_name(0) + + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + sm_scale, + seq_len, + Out, + q_stride_b, + q_stride_s, + q_stride_h, + q_stride_d, + o_stride_b, + o_stride_s, + o_stride_h, + o_stride_d, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = cur_batch * q_stride_b + cur_head * q_stride_h + offs_m[:, None] * q_stride_s + offs_d[None, :] + q = tl.load(Q + off_q, mask=offs_m[:, None] < seq_len, other=0.0) + # initialize pointer to m 
and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+        for start_n in range(0, seq_len, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- compute qk ----
+            off_k = (
+                cur_batch * q_stride_b
+                + (start_n + offs_n[None, :]) * q_stride_s
+                + cur_head * q_stride_h
+                + offs_d[:, None]
+            )
+            k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < seq_len, other=0.0)
+
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk += tl.dot(q, k)
+            qk *= sm_scale
+            qk += tl.where((start_n + offs_n[None, :]) < seq_len, 0, float("-inf"))
+
+            # -- compute m_ij, p, l_ij
+            m_ij = tl.maximum(tl.max(qk, 1), l_i)
+            p = tl.exp(qk - m_ij[:, None])
+            l_ij = tl.sum(p, 1)
+
+            acc_scale = tl.exp(m_i - m_ij)
+            acc = acc * acc_scale[:, None]
+
+            # update acc
+            off_v = (
+                cur_batch * q_stride_b
+                + (start_n + offs_n[:, None]) * q_stride_s
+                + cur_head * q_stride_h
+                + offs_d[None, :]
+            )
+            v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < seq_len, other=0.0)
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v)
+            # update m_i and l_i
+            m_i = m_ij
+            l_i_new = tl.exp(l_i - m_ij) + l_ij
+            l_i = m_ij + tl.log(l_i_new)
+
+        o_scale = tl.exp(m_i - l_i)
+        acc = acc * o_scale[:, None]
+        # initialize pointers to output
+        off_o = cur_batch * o_stride_b + offs_m[:, None] * o_stride_s + cur_head * o_stride_h + offs_d[None, :]
+        out_ptrs = Out + off_o
+        tl.store(out_ptrs, acc, mask=offs_m[:, None] < seq_len)
+        return
+
+    @torch.no_grad()
+    def flash_attention_fwd(
+        q,
+        k,
+        v,
+        o,
+    ):
+        BLOCK = 64
+        # shape constraints
+        batch_size, seq_len, head_num, head_dim = q.shape
+
+        sm_scale = 1.0 / (head_dim ** 0.5)  # compute the scale factor
+        grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head,
+        # grid = (triton.cdiv(seq_len, BLOCK), batch_size, head_num)  # batch, head,
+        num_warps = 4
+        _fwd_kernel[grid](
+            q,
+            k,
+            v,
+            sm_scale,
+            seq_len,
+            o,
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            o.stride(3),
+            BLOCK_M=BLOCK,
+            BLOCK_DMODEL=head_dim,
+            BLOCK_N=BLOCK,
+            num_warps=num_warps,
+            num_stages=2,
+        )
+        return
+
+else:
+    raise Exception("error triton version!")
+
+
+def torch_att(q, k, v):
+    head_dim = q.shape[-1]
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
+    scale = head_dim ** -0.5
+    attn = (q * scale) @ k.transpose(-2, -1)
+    attn = attn.softmax(dim=-1)
+    out = attn @ v
+    out = out.transpose(1, 2).contiguous()
+    return out
+
+
+def test():
+    import torch
+    import numpy as np
+
+    B, L, H, D = 4, 1025, 7, 128
+    dtype = torch.float16
+    q = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+    k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+    v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+    o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+
+    torch_out = torch_att(q, k, v)
+    import time
+
+    torch.cuda.synchronize()
+    a = time.time()
+    for i in range(100):
+        flash_attention_fwd(q, k, v, o)
+        # o = torch_att(q, k, v)
+    torch.cuda.synchronize()
+    b = time.time()
+    # print(o.shape, torch_out.shape)
+    print((b - a) / 100 * 1000)
+
+    print("max ", torch.max(torch.abs(torch_out - o)))
+    print("mean ", torch.mean(torch.abs(torch_out - o)))
+    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
+
+
+# test()
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py
b/lightllm/server/visualserver/model_infer/model_rpc.py index 08c9df17f..465a672b0 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -11,6 +11,7 @@ from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel +from lightllm.models.vit.model_vit import VisionTransformer from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed @@ -32,15 +33,15 @@ def exposed_init_model(self, kvargs): visual_nccl_port = kvargs["visual_nccl_port"] self.vit_rank_id = kvargs["vit_rank_id"] self.cache_client = rpyc.connect("localhost", self.cache_port) + self.data_type = kvargs["data_type"] torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) - if self.vit_tp != 1: - dist.init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{visual_nccl_port}", - rank=self.tp_rank_id, - world_size=self.vit_tp, - ) + dist.init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{visual_nccl_port}", + rank=self.tp_rank_id, + world_size=self.vit_tp, + ) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) if self.vit_tp != 1: @@ -55,6 +56,15 @@ def exposed_init_model(self, kvargs): self.model = LlavaVisionModel() elif self.model_type == "internvl_chat": self.model = InternVLVisionModel() + kvargs = { + "tp_rank": self.tp_rank_id, + "gpu_id": visual_gpu_ids[self.vit_rank_id], + "world_size": self.vit_tp, + "weight_dir": weight_dir, + "data_type": self.data_type, + } + print(kvargs) + self.model = VisionTransformer(kvargs) else: raise Exception(f"can not support {self.model_type} now") From 306a9350d37c06831fa8d3666d861dca3f5fbd59 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 10 Dec 2024 20:37:55 +0800 Subject: [PATCH 02/17] vit quant (pre) --- .../layer_weights/meta_weights/base_weight.py | 2 + .../meta_weights/fused_moe_weight.py | 6 +- .../layer_weights/meta_weights/mm_weight.py | 14 +- .../layer_weights/meta_weights/norm_weight.py | 14 +- .../layer_infer/transformer_layer_infer.py | 63 +-- .../layer_weights/transformer_layer_weight.py | 408 +++++++++++------- 6 files changed, 277 insertions(+), 230 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 4a3accc45..dcfbb247c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -1,3 +1,4 @@ +import os import torch from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_world_size, get_rank @@ -20,6 +21,7 @@ class BaseWeightTpl(BaseWeight): def __init__(self): self.world_size_ = get_world_size() self.tp_rank_ = get_rank() + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_)) def load_hf_weights(self, weights): pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py index cd0f1a5f9..b7605b845 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py 
@@ -1,3 +1,4 @@ +import os import torch from .base_weight import BaseWeight from lightllm.utils.dist_utils import get_world_size, get_rank @@ -21,6 +22,7 @@ def __init__( self.split_inter_size = split_inter_size self.data_type_ = data_type self.tp_rank_ = get_rank() + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_)) self.experts_up_projs = [None] * self.n_routed_experts self.experts_gate_projs = [None] * self.n_routed_experts self.w2_list = [None] * self.n_routed_experts @@ -113,10 +115,10 @@ def load_hf_weights(self, weights): self._fuse() def _cuda(self, cpu_tensor): - if self.tp_rank_ is None: + if self.device_id_ is None: return cpu_tensor.contiguous().to(self.data_type_).cuda() else: - return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_) + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.device_id_) def verify_load(self): return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py index d188d5511..f271082a8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py @@ -31,9 +31,9 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True): def _post_load_weights(self): if self.quant_method is not None: - self.weight = self.quant_method.quantize(self.weight.cuda(self.tp_rank_)) + self.weight = self.quant_method.quantize(self.weight.cuda(self.device_id_)) return - self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_) + self.weight = self.weight.transpose(0, 1).cuda(self.device_id_) class MMWeight(MMWeightTpl): @@ -65,7 +65,7 @@ def load_hf_weights(self, weights): self.weight = weight[self.start : self.end] if self.bias_name in weights: bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end] - self.bias = bias.cuda(self.tp_rank_) + self.bias = bias.cuda(self.device_id_) if weight is None: return self._post_load_weights() @@ -90,7 +90,7 @@ def load_hf_weights(self, weights): self.weight = weight[:, self.start : self.end] if self.bias_name in weights: bias = weights[self.bias_name] - self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.tp_rank_) + self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_) if weight is None: return self._post_load_weights() @@ -133,7 +133,7 @@ def _fuse(self): self._post_load_weights() if self.has_bias: if self.bias is None and all(b is not None for b in self.biases): - self.bias = torch.cat(self.biases, dim=0).cuda(self.tp_rank_) + self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_) return self def load_hf_weights(self, weights): @@ -200,7 +200,7 @@ def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True): return torch.addbmm(self.bias, input_tensor, self.weight, out=out) def _post_load_weights(self): - self.weight = self.weight.cuda(self.tp_rank_) + self.weight = self.weight.cuda(self.device_id_) class BMMWeight(BMMWeightTpl): @@ -247,4 +247,4 @@ def __init__( super().__init__(weight_name, data_type, split_n_embed, bias_name) def _post_load_weights(self): - self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_) + self.weight = self.weight.transpose(0, 1).cuda(self.device_id_) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 029d723c9..3cf88d1d8 100644 --- 
a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -12,9 +12,9 @@ def __init__(self, weight_name, data_type, bias_name=None): def load_hf_weights(self, weights): if self.weight_name in weights: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.tp_rank_) + self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.device_id_) if self.bias_name in weights: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.tp_rank_) + self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.device_id_) def verify_load(self): load_ok = True @@ -32,7 +32,7 @@ def __init__(self, weight_name, data_type, bias_name=None): def load_hf_weights(self, weights): if self.weight_name in weights: - self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.tp_rank_) + self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.device_id_) class TpNormWeight(NormWeight): @@ -41,10 +41,10 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): self.split_n_embed = split_n_embed def load_hf_weights(self, weights): - start = self.offset + self.split_n_embed * self.tp_rank_ - end = self.offset + self.split_n_embed * (self.tp_rank_ + 1) + start = self.offset + self.split_n_embed * self.device_id_ + end = self.offset + self.split_n_embed * (self.device_id_ + 1) if self.weight_name in weights: - self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.tp_rank_) + self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_) if self.bias_name in weights: - self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.tp_rank_) + self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.device_id_) diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 5ffcbd841..5a28b980d 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -47,105 +47,74 @@ def tp_norm(self, input, weight): def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: if layer_weight.norm_type == "rms_norm": - b = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_, eps=self.eps_) + b = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_) else: b = torch.nn.functional.layer_norm( input, normalized_shape=[1024], - weight=layer_weight.att_norm_weight_, - bias=layer_weight.att_norm_bias_, + weight=layer_weight.att_norm_weight_.weight, + bias=layer_weight.att_norm_weight_.bias, eps=layer_weight.layer_norm_eps, ) - # b = torch.empty_like(input) - # rms_norm(b, input, layer_weight.att_norm_weight_, self.eps_ , 4) return b def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: - # return self.norm(input, layer_weight.ffn_norm_weight_) if layer_weight.norm_type == "rms_norm": - return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_, eps=self.eps_) + return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_) else: return torch.nn.functional.layer_norm( input, normalized_shape=[1024], - weight=layer_weight.ffn_norm_weight_, - bias=layer_weight.ffn_norm_bias_, + weight=layer_weight.ffn_norm_weight_.weight, + bias=layer_weight.ffn_norm_weight_.bias, eps=layer_weight.layer_norm_eps, 
) def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: - q_norm = self.tp_norm(q, layer_weight.q_norm_weight_) - k_norm = self.tp_norm(k, layer_weight.k_norm_weight_) - # import numpy as np - # q_norm = rmsnorm_forward(q, weight=layer_weight.q_norm_weight_, eps=self.eps_) - # k_norm = rmsnorm_forward(k, weight=layer_weight.k_norm_weight_, eps=self.eps_) + q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight) + k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight) return q_norm, k_norm def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: batch_size = input.shape[0] seq_len = input.shape[1] - if layer_weight.qkv_bias: - qkv = torch.addmm(layer_weight.qkv_bias_, input.view(-1, self.embed_dim_), layer_weight.qkv_weight_).view( - batch_size, seq_len, 3, -1, self.head_dim_ - ) - q, k, v = qkv.unbind(2) - else: - q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_).view(batch_size, seq_len, -1) - k = torch.mm(input.view(-1, self.embed_dim_), layer_weight.k_weight_).view(batch_size, seq_len, -1) - v = torch.mm(input.view(-1, self.embed_dim_), layer_weight.v_weight_).view(batch_size, seq_len, -1) + qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_)) + qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_) + q, k, v = qkv.unbind(2) return q, k, v def _context_attention_kernel(self, q, k, v) -> torch.Tensor: out = torch.empty_like(q) batch_size = q.shape[0] seq_len = q.shape[1] - # import time - # torch.cuda.synchronize() - # a = time.time() - # for i in range(100): flash_attention_fwd(q, k, v, out) - # torch.cuda.synchronize() - # b = time.time() - # print(f"{self.layer_num_} The time is {(b - a) * 10}") return out.reshape(batch_size, seq_len, -1) def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: batch_size = input.shape[0] seq_len = input.shape[1] - o_tensor = torch.addmm( - layer_weight.o_bias_, - input.view(-1, self.tp_padding_head_num * self.head_dim_), - layer_weight.o_weight_, - ) + o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_padding_head_num * self.head_dim_)) return o_tensor.reshape((batch_size, seq_len, -1)) def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: - fc1 = torch.addmm(layer_weight.fc1_bias_, input.view(-1, self.embed_dim_), layer_weight.fc1_weight_) + fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_)) ffn1_out = torch.nn.functional.gelu(fc1) input_shape = input.shape input = None - ffn2_out = torch.addmm(layer_weight.fc2_bias_, ffn1_out, layer_weight.fc2_weight_) + ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out) ffn1_out = None return ffn2_out.reshape(input_shape) def _context_attention(self, input_embding, layer_weight): input1 = self._att_norm(input_embding, layer_weight) - # import time - # torch.cuda.synchronize() - # a = time.time() q, k, v = self._get_qkv(input1, layer_weight) - # if self.qk_norm: - # q, k = self._qk_norm(q, k, layer_weight) - # input1 = None + if layer_weight.qk_norm: + q, k = self._qk_norm(q, k, layer_weight) o = self._context_attention_kernel(q, k, v) - # q = None o = self._get_o(o, layer_weight) if self.world_size_ > 1: dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False) input_embding.add_(o) - # torch.cuda.synchronize() - # b = time.time() - # print(f"{self.layer_num_} The time is {(b - a) * 1000}", layer_weight.o_weight_.shape) return def _context_ffn(self, input_embdings, layer_weight): diff --git 
a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index fa4737a98..e602a5b11 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -3,209 +3,283 @@ import numpy as np import torch.nn.functional as F from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight, MultiROWMMWeight class ViTTransformerLayerWeight(TransformerLayerWeight): def __init__(self, layer_num, tp_rank, gpu_id, world_size, data_type, network_config, mode=[], quant_cfg=None): - self.padding_hidden_size = network_config["padding_hidden_size"] - self.qk_norm = network_config["qk_normalization"] - self.use_ls = network_config.get("use_ls", False) - self.qkv_bias = network_config.get("qkv_bias", True) - self.layer_norm_eps = network_config.get("layer_norm_eps", 1e-6) - self.norm_type = network_config.get("norm_type", "layer_norm") - self.gpu_id_ = gpu_id super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg) + self.gpu_id_ = gpu_id return - def _cuda(self, cpu_tensor): - return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) - - def _try_cat_to(self, source_tensor_names, dest_name, cat_dim, handle_func=None): - if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name): - with self.lock: - if all(hasattr(self, src_name) for src_name in source_tensor_names) and not hasattr(self, dest_name): - assert all( - not getattr(self, name, None).is_cuda for name in source_tensor_names - ), "all not cuda tensor" - tensors = [getattr(self, name, None) for name in source_tensor_names] - ans = torch.cat(tensors, dim=cat_dim) - if handle_func is not None: - ans = handle_func(ans) - else: - ans = self._cuda(ans) - setattr(self, dest_name, ans) - for name in source_tensor_names: - delattr(self, name) - return - - def load_hf_weights(self, weights): - self._load_qkvo_weights(weights) - self._load_ffn_weights(weights) - return - - def post_load(self): - # merge ls - ls1 = self.ls1.to(torch.float64) - self.o_bias_ = (self.o_bias_.to(torch.float64) * ls1).to(self.data_type_) - self.o_weight_ = (self.o_weight_.to(torch.float64) * ls1.reshape(1, -1)).to(self.data_type_) - - ls2 = self.ls2.to(torch.float64) - self.fc2_bias_ = (self.fc2_bias_.to(torch.float64) * ls2).to(self.data_type_) - self.fc2_weight_ = (self.fc2_weight_.to(torch.float64) * ls2.reshape(1, -1)).to(self.data_type_) - del self.ls1 - del self.ls2 - torch.cuda.empty_cache() - - def _transpose(self, ans): - return ans.t().cuda(self.gpu_id_) - - def verify_load(self): - errors = "weights load not ok" - if not self.qk_norm: - self.q_norm_weight_ = torch.ones(1).cuda(self.gpu_id_) - self.k_norm_weight_ = torch.ones(1).cuda(self.gpu_id_) - if not self.use_ls: - self.ls1 = 1.0 - self.ls2 = 1.0 - - weights = [ - self.att_norm_weight_, - self.q_norm_weight_, - self.k_norm_weight_, - # self.q_weight_, - self.o_weight_, - self.o_bias_, - self.ffn_norm_weight_, - self.fc1_weight_, - self.fc1_bias_, - self.fc2_weight_, - self.fc2_bias_, - self.ls1, - self.ls2, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - self.post_load() - return - - def _load_qkvo_weights(self, weights): - # input layernorm params - if f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" in weights: - 
self.att_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight"]) - if f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias" in weights: - self.att_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias"]) - + def _parse_config(self): + self.padding_hidden_size = self.network_config_["padding_hidden_size"] + self.qk_norm = self.network_config_["qk_normalization"] + self.use_ls = self.network_config_.get("use_ls", False) + self.qkv_bias = self.network_config_.get("qkv_bias", True) + self.layer_norm_eps = self.network_config_.get("layer_norm_eps", 1e-6) + self.norm_type = self.network_config_.get("norm_type", "layer_norm") + + def _init_weight_names(self): + self._att_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" + + self._q_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.q.weight" + self._k_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.k.weight" + self._v_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.v.weight" + + if self.qkv_bias: + self._q_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.q.bias" + self._k_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.k.bias" + self._v_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.v.bias" + else: + self._q_bias_name = None + self._k_bias_name = None + self._v_bias_name = None + + self._o_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight" + self._o_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias" + + self.fc1_weight_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight" + self.fc1_bias_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias" + self.fc2_weight_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight" + self.fc2_bias_name_ = f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias" + + self._ls1_name = f"vision_model.encoder.layers.{self.layer_num_}.ls1" + self._ls2_name = f"vision_model.encoder.layers.{self.layer_num_}.ls2" + + self._att_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" + self._ffn_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight" + + if self.norm_type == "layer_norm": + self._att_norm_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias" + self._ffn_norm_bias_name = f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias" + else: + self._att_norm_bias_name = None + self._ffn_norm_bias_name = None + + if self.qk_norm: + self._q_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight" + self._k_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight" + + def _init_weight(self): + self._init_qkv() + self._init_o() + self._init_ffn() + self._init_norm() + + def _init_qkv(self): n_embed = self.network_config_["hidden_size"] - split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ - if f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight" in weights: - q_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight"] - q_norm_weight_ = F.pad(q_norm_weight_, (0, self.padding_hidden_size)) - self.q_norm_weight_ = self._cuda( - q_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - ) + qkv_split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + 
self.qkv_proj = MultiROWMMWeight( + [self._q_weight_name, self._k_weight_name, self._v_weight_name], + self.data_type_, + qkv_split_n_embed, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + ) + + def _init_o(self): + n_embed = self.network_config_["hidden_size"] + o_split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + self.o_proj = COLMMWeight(self._o_weight_name, self.data_type_, o_split_n_embed, bias_name=self._o_bias_name) - if f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight" in weights: - k_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight"] - k_norm_weight_ = F.pad(k_norm_weight_, (0, self.padding_hidden_size)) - self.k_norm_weight_ = self._cuda( - k_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + def _init_ffn(self): + inter_size = self.network_config_["intermediate_size"] + split_inter_size = inter_size // self.world_size_ + self.ffn_1_proj_ = ROWMMWeight( + self.fc1_weight_name_, + self.data_type_, + split_inter_size, + bias_name=self.fc1_bias_name_, + ) + self.ffn_2_proj_ = COLMMWeight( + self.fc2_weight_name_, self.data_type_, split_inter_size, bias_name=self.fc2_bias_name_ + ) + + def _init_norm(self): + self.att_norm_weight_ = NormWeight( + self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + ) + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) + if self.qk_norm: + self.q_norm_weight_ = NormWeight( + self._q_norm_weight_name, self.data_type_, bias_name=self._q_norm_bias_name + ) + self.k_norm_weight_ = NormWeight( + self._k_norm_weight_name, self.data_type_, bias_name=self._q_norm_bias_name ) - # q k v weights for llama + def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: + n_embed = self.network_config_["hidden_size"] + split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ att_qkv_dense_weight = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] - att_qkv_dense_weight = att_qkv_dense_weight.reshape(3, n_embed, -1) - # self.qkv_weight_ = self._cuda(att_qkv_dense_weight).t() - q_weight_ = F.pad(att_qkv_dense_weight[0, :, :], (0, 0, 0, self.padding_hidden_size)) - self.q_weight_ = q_weight_.reshape(-1, n_embed)[ + q_weight_ = q_weight_.reshape(-1, n_embed)[ split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : ] - k_weight_ = F.pad(att_qkv_dense_weight[1, :, :], (0, 0, 0, self.padding_hidden_size)) - self.k_weight_ = k_weight_.reshape(-1, n_embed)[ + k_weight_ = k_weight_.reshape(-1, n_embed)[ split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : ] - v_weight_ = F.pad(att_qkv_dense_weight[2, :, :], (0, 0, 0, self.padding_hidden_size)) - self.v_weight_ = v_weight_.reshape(-1, n_embed)[ + v_weight_ = v_weight_.reshape(-1, n_embed)[ split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : ] + del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] + weights[self._q_weight_name] = q_weight_ + weights[self._k_weight_name] = k_weight_ + weights[self._v_weight_name] = v_weight_ if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias" in weights: + n_embed = self.network_config_["hidden_size"] + split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ att_qkv_dense_bias = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] - # 
self.qkv_bias_ = self._cuda(att_qkv_dense_bias) - self.q_bias_ = att_qkv_dense_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - self.k_bias_ = att_qkv_dense_bias[ + att_qkv_dense_bias = F.pad(att_qkv_dense_bias, (0, self.padding_hidden_size)) + q_bias_ = att_qkv_dense_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + k_bias_ = att_qkv_dense_bias[ n_embed + split_n_embed * self.tp_rank_ : n_embed + split_n_embed * (self.tp_rank_ + 1) ] - self.v_bias_ = att_qkv_dense_bias[ + v_bias_ = att_qkv_dense_bias[ n_embed * 2 + split_n_embed * self.tp_rank_ : n_embed * 2 + split_n_embed * (self.tp_rank_ + 1) ] - - self._try_cat_to(["q_weight_", "k_weight_", "v_weight_"], "qkv_weight_", cat_dim=0, handle_func=self._transpose) - self._try_cat_to(["q_bias_", "k_bias_", "v_bias_"], "qkv_bias_", cat_dim=0) - # attention output dense params - if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight" in weights: - o_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight"] - o_weight_ = F.pad(o_weight_, (0, self.padding_hidden_size, 0, 0)) - o_weight_ = o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - # print(o_weight_.shape, o_weight_) - self.o_weight_ = self._cuda(o_weight_).t() - if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias" in weights: - o_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias"] - if self.tp_rank_ == 0: - self.o_bias_ = self._cuda(o_bias_) - else: - self.o_bias_ = self._cuda(torch.zeros_like(o_bias_)) + weights[self._q_bias_name] = q_bias_ + weights[self._k_bias_name] = k_bias_ + weights[self._v_bias_name] = v_bias_ + del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"] self.ls1 = self._cuda(ls1) - self.use_ls = True - - # self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) - - return - - def _load_ffn_weights(self, weights): - if f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight" in weights: - self.ffn_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight"]) - - if f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias" in weights: - self.ffn_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias"]) - - inter_size = self.network_config_["intermediate_size"] - split_inter_size = inter_size // self.world_size_ - - if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight" in weights: - fc1_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight"][ - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : - ] - self.fc1_weight_ = self._cuda(fc1_weight_).t() - - if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias" in weights: - fc1_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias"][ - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) - ] - self.fc1_bias_ = self._cuda(fc1_bias_) - - if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight" in weights: - fc2_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight"][ - :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) - ] - self.fc2_weight_ = self._cuda(fc2_weight_).t() - - if 
f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias" in weights: - fc2_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias"] - if self.tp_rank_ == 0: - self.fc2_bias_ = self._cuda(fc2_bias_) - else: - self.fc2_bias_ = self._cuda(torch.zeros_like(fc2_bias_)) if f"vision_model.encoder.layers.{self.layer_num_}.ls2" in weights: ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] self.ls2 = self._cuda(ls2) + self.use_ls = True - return + return super().load_hf_weights(weights) + + # def _load_qkvo_weights(self, weights): + # # input layernorm params + # if f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" in weights: + # self.att_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight"]) + # if f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias" in weights: + # self.att_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias"]) + + # n_embed = self.network_config_["hidden_size"] + # split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + # if f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight" in weights: + # q_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight"] + # q_norm_weight_ = F.pad(q_norm_weight_, (0, self.padding_hidden_size)) + # self.q_norm_weight_ = self._cuda( + # q_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + # ) + + # if f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight" in weights: + # k_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight"] + # k_norm_weight_ = F.pad(k_norm_weight_, (0, self.padding_hidden_size)) + # self.k_norm_weight_ = self._cuda( + # k_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + # ) + + # # q k v weights for llama + # if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: + # att_qkv_dense_weight = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] + + # att_qkv_dense_weight = att_qkv_dense_weight.reshape(3, n_embed, -1) + # # self.qkv_weight_ = self._cuda(att_qkv_dense_weight).t() + + # q_weight_ = F.pad(att_qkv_dense_weight[0, :, :], (0, 0, 0, self.padding_hidden_size)) + # self.q_weight_ = q_weight_.reshape(-1, n_embed)[ + # split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + # ] + + # k_weight_ = F.pad(att_qkv_dense_weight[1, :, :], (0, 0, 0, self.padding_hidden_size)) + # self.k_weight_ = k_weight_.reshape(-1, n_embed)[ + # split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + # ] + + # v_weight_ = F.pad(att_qkv_dense_weight[2, :, :], (0, 0, 0, self.padding_hidden_size)) + # self.v_weight_ = v_weight_.reshape(-1, n_embed)[ + # split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + # ] + + # if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias" in weights: + # att_qkv_dense_bias = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] + # # self.qkv_bias_ = self._cuda(att_qkv_dense_bias) + # self.q_bias_ = att_qkv_dense_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + # self.k_bias_ = att_qkv_dense_bias[ + # n_embed + split_n_embed * self.tp_rank_ : n_embed + split_n_embed * (self.tp_rank_ + 1) + # ] + # self.v_bias_ = att_qkv_dense_bias[ + # n_embed * 2 + split_n_embed * self.tp_rank_ : n_embed * 2 + split_n_embed * (self.tp_rank_ + 1) + # ] + + # 
self._try_cat_to(["q_weight_", "k_weight_", "v_weight_"], "qkv_weight_", cat_dim=0, + # handle_func=self._transpose) + # self._try_cat_to(["q_bias_", "k_bias_", "v_bias_"], "qkv_bias_", cat_dim=0) + # # attention output dense params + # if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight" in weights: + # o_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight"] + # o_weight_ = F.pad(o_weight_, (0, self.padding_hidden_size, 0, 0)) + # o_weight_ = o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] + # # print(o_weight_.shape, o_weight_) + # self.o_weight_ = self._cuda(o_weight_).t() + # if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias" in weights: + # o_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias"] + # if self.tp_rank_ == 0: + # self.o_bias_ = self._cuda(o_bias_) + # else: + # self.o_bias_ = self._cuda(torch.zeros_like(o_bias_)) + + # if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: + # ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"] + # self.ls1 = self._cuda(ls1) + # self.use_ls = True + + # # self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + + # return + + # def _load_ffn_weights(self, weights): + # if f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight" in weights: + # self.ffn_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight"]) + + # if f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias" in weights: + # self.ffn_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias"]) + + # inter_size = self.network_config_["intermediate_size"] + # split_inter_size = inter_size // self.world_size_ + + # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight" in weights: + # fc1_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight"][ + # split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + # ] + # self.fc1_weight_ = self._cuda(fc1_weight_).t() + + # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias" in weights: + # fc1_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias"][ + # split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + # ] + # self.fc1_bias_ = self._cuda(fc1_bias_) + + # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight" in weights: + # fc2_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight"][ + # :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + # ] + # self.fc2_weight_ = self._cuda(fc2_weight_).t() + + # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias" in weights: + # fc2_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias"] + # if self.tp_rank_ == 0: + # self.fc2_bias_ = self._cuda(fc2_bias_) + # else: + # self.fc2_bias_ = self._cuda(torch.zeros_like(fc2_bias_)) + + # if f"vision_model.encoder.layers.{self.layer_num_}.ls2" in weights: + # ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] + # self.ls2 = self._cuda(ls2) + + # return From d0e59925acd58449b2936d20c3c2feb3d3766ada Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 11 Dec 2024 15:41:27 +0800 Subject: [PATCH 03/17] vit 6b tp --- .../layer_weights/meta_weights/norm_weight.py | 4 +- .../models/vit/layer_infer/pre_layer_infer.py | 8 +- .../layer_infer/transformer_layer_infer.py | 20 +- 
.../pre_and_post_layer_weight.py | 5 +- .../layer_weights/transformer_layer_weight.py | 184 ++++-------------- .../models/vit/{model_vit.py => model.py} | 6 +- lightllm/server/api_cli.py | 2 +- .../visualserver/model_infer/model_rpc.py | 7 +- 8 files changed, 62 insertions(+), 174 deletions(-) rename lightllm/models/vit/{model_vit.py => model.py} (95%) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 3cf88d1d8..5ba48516c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -41,8 +41,8 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): self.split_n_embed = split_n_embed def load_hf_weights(self, weights): - start = self.offset + self.split_n_embed * self.device_id_ - end = self.offset + self.split_n_embed * (self.device_id_ + 1) + start = self.split_n_embed * self.device_id_ + end = self.split_n_embed * (self.device_id_ + 1) if self.weight_name in weights: self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_) diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py index c9134e3c2..93de2fc0b 100644 --- a/lightllm/models/vit/layer_infer/pre_layer_infer.py +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -5,14 +5,16 @@ import numpy as np from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight -from lightllm.common.basemodel import PreLayerInferTpl -class ViTPreLayerInfer(PreLayerInferTpl): +class ViTPreLayerInfer: """ """ def __init__(self, tp_rank, world_size, network_config, mode): - super().__init__(tp_rank, world_size, network_config, mode) + self.tp_rank_ = tp_rank + self.world_size_ = world_size + self.network_config_ = network_config + self.mode = mode return def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 5a28b980d..fe030f321 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -37,13 +37,17 @@ def norm(self, input, weight): return weight * input.to(input_dtype) def tp_norm(self, input, weight): + input_shape = input.shape + input = input.view(-1, self.tp_padding_head_num * self.head_dim_) input_dtype = input.dtype input = input.to(torch.float32) tp_variance = input.pow(2).sum(-1, keepdim=True) dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False) variance = tp_variance / self.embed_dim_ input = input * torch.rsqrt(variance + self.eps_) - return weight * input.to(input_dtype) + out = weight * input.to(input_dtype) + out = out.reshape(input_shape) + return out def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: if layer_weight.norm_type == "rms_norm": @@ -78,7 +82,7 @@ def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tenso def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: batch_size = input.shape[0] seq_len = input.shape[1] - qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_)) + qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False) qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_) q, k, v 
= qkv.unbind(2) return q, k, v @@ -93,15 +97,21 @@ def _context_attention_kernel(self, q, k, v) -> torch.Tensor: def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: batch_size = input.shape[0] seq_len = input.shape[1] - o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_padding_head_num * self.head_dim_)) + o_tensor = layer_weight.o_proj.mm( + input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=False + ) + if layer_weight.use_ls: + o_tensor *= layer_weight.ls1 return o_tensor.reshape((batch_size, seq_len, -1)) def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: - fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_)) + fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False) ffn1_out = torch.nn.functional.gelu(fc1) input_shape = input.shape input = None - ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out) + ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False) + if layer_weight.use_ls: + ffn2_out *= layer_weight.ls2 ffn1_out = None return ffn2_out.reshape(input_shape) diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index eb0cffddf..f56a884a6 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -1,3 +1,4 @@ +import os import torch import numpy as np import torch.nn.functional as F @@ -5,13 +6,13 @@ class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, tp_rank, gpud_id, world_size, data_type, network_config, mode): + def __init__(self, tp_rank, world_size, data_type, network_config, mode): super().__init__(tp_rank, world_size, data_type, network_config, mode) self.embed_dim = self.network_config_["hidden_size"] self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] - self.gpu_id_ = gpud_id + self.gpu_id_ = int(os.getenv("CURRENT_DEVICE_ID", tp_rank)) return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index e602a5b11..2676ff3ed 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -1,17 +1,28 @@ +import os import torch import math import numpy as np import torch.nn.functional as F from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight, MultiROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + NormWeight, + MultiROWMMWeight, + TpNormWeight, +) class ViTTransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, tp_rank, gpu_id, world_size, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg) - self.gpu_id_ = gpu_id + self.gpu_id_ = int(os.getenv("CURRENT_DEVICE_ID", tp_rank)) + return + def _cuda(self, cpu_tensor): + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) + def 
_parse_config(self): self.padding_hidden_size = self.network_config_["padding_hidden_size"] self.qk_norm = self.network_config_["qk_normalization"] @@ -60,6 +71,8 @@ def _init_weight_names(self): if self.qk_norm: self._q_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight" self._k_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight" + self._q_norm_bias_name = None + self._k_norm_bias_name = None def _init_weight(self): self._init_qkv() @@ -103,53 +116,45 @@ def _init_norm(self): self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name ) if self.qk_norm: - self.q_norm_weight_ = NormWeight( - self._q_norm_weight_name, self.data_type_, bias_name=self._q_norm_bias_name - ) - self.k_norm_weight_ = NormWeight( - self._k_norm_weight_name, self.data_type_, bias_name=self._q_norm_bias_name - ) + n_embed = self.network_config_["hidden_size"] + split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ + self.q_norm_weight_ = TpNormWeight(self._q_norm_weight_name, self.data_type_, split_n_embed) + self.k_norm_weight_ = TpNormWeight(self._k_norm_weight_name, self.data_type_, split_n_embed) def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: n_embed = self.network_config_["hidden_size"] - split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ att_qkv_dense_weight = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] att_qkv_dense_weight = att_qkv_dense_weight.reshape(3, n_embed, -1) q_weight_ = F.pad(att_qkv_dense_weight[0, :, :], (0, 0, 0, self.padding_hidden_size)) - q_weight_ = q_weight_.reshape(-1, n_embed)[ - split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : - ] k_weight_ = F.pad(att_qkv_dense_weight[1, :, :], (0, 0, 0, self.padding_hidden_size)) - k_weight_ = k_weight_.reshape(-1, n_embed)[ - split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : - ] v_weight_ = F.pad(att_qkv_dense_weight[2, :, :], (0, 0, 0, self.padding_hidden_size)) - v_weight_ = v_weight_.reshape(-1, n_embed)[ - split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : - ] del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] weights[self._q_weight_name] = q_weight_ weights[self._k_weight_name] = k_weight_ weights[self._v_weight_name] = v_weight_ + if self._o_weight_name in weights: + weights[self._o_weight_name] = F.pad(weights[self._o_weight_name], (0, self.padding_hidden_size, 0, 0)) + if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias" in weights: n_embed = self.network_config_["hidden_size"] - split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ att_qkv_dense_bias = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] - att_qkv_dense_bias = F.pad(att_qkv_dense_bias, (0, self.padding_hidden_size)) - q_bias_ = att_qkv_dense_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - k_bias_ = att_qkv_dense_bias[ - n_embed + split_n_embed * self.tp_rank_ : n_embed + split_n_embed * (self.tp_rank_ + 1) - ] - v_bias_ = att_qkv_dense_bias[ - n_embed * 2 + split_n_embed * self.tp_rank_ : n_embed * 2 + split_n_embed * (self.tp_rank_ + 1) - ] + att_qkv_dense_bias = F.pad(att_qkv_dense_bias, (0, self.padding_hidden_size)).reshape(3, -1) + q_bias_ = att_qkv_dense_bias[0] + k_bias_ = att_qkv_dense_bias[1] + v_bias_ = att_qkv_dense_bias[2] weights[self._q_bias_name] = q_bias_ 
weights[self._k_bias_name] = k_bias_ weights[self._v_bias_name] = v_bias_ del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] + if self._q_norm_weight_name in weights: + weights[self._q_norm_weight_name] = F.pad(weights[self._q_norm_weight_name], (0, self.padding_hidden_size)) + + if self._k_norm_weight_name in weights: + weights[self._k_norm_weight_name] = F.pad(weights[self._k_norm_weight_name], (0, self.padding_hidden_size)) + if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"] self.ls1 = self._cuda(ls1) @@ -160,126 +165,3 @@ def load_hf_weights(self, weights): self.use_ls = True return super().load_hf_weights(weights) - - # def _load_qkvo_weights(self, weights): - # # input layernorm params - # if f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" in weights: - # self.att_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight"]) - # if f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias" in weights: - # self.att_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm1.bias"]) - - # n_embed = self.network_config_["hidden_size"] - # split_n_embed = (n_embed + self.padding_hidden_size) // self.world_size_ - # if f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight" in weights: - # q_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.q_norm.weight"] - # q_norm_weight_ = F.pad(q_norm_weight_, (0, self.padding_hidden_size)) - # self.q_norm_weight_ = self._cuda( - # q_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - # ) - - # if f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight" in weights: - # k_norm_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.k_norm.weight"] - # k_norm_weight_ = F.pad(k_norm_weight_, (0, self.padding_hidden_size)) - # self.k_norm_weight_ = self._cuda( - # k_norm_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - # ) - - # # q k v weights for llama - # if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: - # att_qkv_dense_weight = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight"] - - # att_qkv_dense_weight = att_qkv_dense_weight.reshape(3, n_embed, -1) - # # self.qkv_weight_ = self._cuda(att_qkv_dense_weight).t() - - # q_weight_ = F.pad(att_qkv_dense_weight[0, :, :], (0, 0, 0, self.padding_hidden_size)) - # self.q_weight_ = q_weight_.reshape(-1, n_embed)[ - # split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : - # ] - - # k_weight_ = F.pad(att_qkv_dense_weight[1, :, :], (0, 0, 0, self.padding_hidden_size)) - # self.k_weight_ = k_weight_.reshape(-1, n_embed)[ - # split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : - # ] - - # v_weight_ = F.pad(att_qkv_dense_weight[2, :, :], (0, 0, 0, self.padding_hidden_size)) - # self.v_weight_ = v_weight_.reshape(-1, n_embed)[ - # split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : - # ] - - # if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias" in weights: - # att_qkv_dense_bias = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] - # # self.qkv_bias_ = self._cuda(att_qkv_dense_bias) - # self.q_bias_ = att_qkv_dense_bias[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - # self.k_bias_ = att_qkv_dense_bias[ - # n_embed + split_n_embed * 
self.tp_rank_ : n_embed + split_n_embed * (self.tp_rank_ + 1) - # ] - # self.v_bias_ = att_qkv_dense_bias[ - # n_embed * 2 + split_n_embed * self.tp_rank_ : n_embed * 2 + split_n_embed * (self.tp_rank_ + 1) - # ] - - # self._try_cat_to(["q_weight_", "k_weight_", "v_weight_"], "qkv_weight_", cat_dim=0, - # handle_func=self._transpose) - # self._try_cat_to(["q_bias_", "k_bias_", "v_bias_"], "qkv_bias_", cat_dim=0) - # # attention output dense params - # if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight" in weights: - # o_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.weight"] - # o_weight_ = F.pad(o_weight_, (0, self.padding_hidden_size, 0, 0)) - # o_weight_ = o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - # # print(o_weight_.shape, o_weight_) - # self.o_weight_ = self._cuda(o_weight_).t() - # if f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias" in weights: - # o_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.proj.bias"] - # if self.tp_rank_ == 0: - # self.o_bias_ = self._cuda(o_bias_) - # else: - # self.o_bias_ = self._cuda(torch.zeros_like(o_bias_)) - - # if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: - # ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"] - # self.ls1 = self._cuda(ls1) - # self.use_ls = True - - # # self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) - - # return - - # def _load_ffn_weights(self, weights): - # if f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight" in weights: - # self.ffn_norm_weight_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.weight"]) - - # if f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias" in weights: - # self.ffn_norm_bias_ = self._cuda(weights[f"vision_model.encoder.layers.{self.layer_num_}.norm2.bias"]) - - # inter_size = self.network_config_["intermediate_size"] - # split_inter_size = inter_size // self.world_size_ - - # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight" in weights: - # fc1_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.weight"][ - # split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : - # ] - # self.fc1_weight_ = self._cuda(fc1_weight_).t() - - # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias" in weights: - # fc1_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc1.bias"][ - # split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) - # ] - # self.fc1_bias_ = self._cuda(fc1_bias_) - - # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight" in weights: - # fc2_weight_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.weight"][ - # :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) - # ] - # self.fc2_weight_ = self._cuda(fc2_weight_).t() - - # if f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias" in weights: - # fc2_bias_ = weights[f"vision_model.encoder.layers.{self.layer_num_}.mlp.fc2.bias"] - # if self.tp_rank_ == 0: - # self.fc2_bias_ = self._cuda(fc2_bias_) - # else: - # self.fc2_bias_ = self._cuda(torch.zeros_like(fc2_bias_)) - - # if f"vision_model.encoder.layers.{self.layer_num_}.ls2" in weights: - # ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] - # self.ls2 = self._cuda(ls2) - - # return diff --git a/lightllm/models/vit/model_vit.py b/lightllm/models/vit/model.py similarity index 95% rename from 
lightllm/models/vit/model_vit.py rename to lightllm/models/vit/model.py index a616856d7..56fdfda80 100644 --- a/lightllm/models/vit/model_vit.py +++ b/lightllm/models/vit/model.py @@ -34,7 +34,6 @@ class VisionTransformer: def __init__(self, kvargs): self.tp_rank_ = kvargs["tp_rank"] - self.gpu_id_ = kvargs["gpu_id"] self.world_size_ = kvargs["world_size"] self.weight_dir_ = kvargs["weight_dir"] self.load_way = kvargs.get("load_way", "HF") @@ -55,7 +54,6 @@ def __init__(self, kvargs): def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) - # self.select_layer = self.config["vision_config"]["vit_select_layer"] self.select_layer = self.config["select_layer"] self.config["vision_config"]["llm_hidden_size"] = self.config["llm_config"]["hidden_size"] self.config["vision_config"]["downsample_ratio"] = self.config["downsample_ratio"] @@ -78,13 +76,12 @@ def _padding_hidden_size(self): def _init_weights(self): self.pre_post_weight = self.pre_and_post_weight_class( - self.tp_rank_, self.gpu_id_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode + self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode ) self.trans_layers_weight = [ self.transformer_weight_class( i, self.tp_rank_, - self.gpu_id_, self.world_size_, self.data_type, network_config=self.config, @@ -133,7 +130,6 @@ def _init_datatype(self): else: raise ValueError(f"Unsupport datatype {self.data_type}!") - # @torch.no_grad() def forward(self, pixel_values): input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) for i in range(self.layers_num + self.select_layer + 1): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index b2e588628..fea7d371d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -197,7 +197,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") parser.add_argument( - "--visual_infer_batch_size", type=int, default=4, help="number of images to process in each inference batch" + "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch" ) parser.add_argument( "--visual_gpu_ids", nargs="+", type=int, default=[0], help="List of GPU IDs to use, e.g., 0 1 2" diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 465a672b0..110e1abfe 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -11,7 +11,7 @@ from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel -from lightllm.models.vit.model_vit import VisionTransformer +from lightllm.models.vit.model import VisionTransformer from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed @@ -43,9 +43,8 @@ def exposed_init_model(self, kvargs): world_size=self.vit_tp, ) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + os.environ["CURRENT_DEVICE_ID"] = str(visual_gpu_ids[self.vit_rank_id]) - if self.vit_tp != 1: - raise 
ValueError(f"ERROR: Not support vit_tp value: {self.vit_tp}") try: self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -55,10 +54,8 @@ def exposed_init_model(self, kvargs): elif self.model_type == "llava": self.model = LlavaVisionModel() elif self.model_type == "internvl_chat": - self.model = InternVLVisionModel() kvargs = { "tp_rank": self.tp_rank_id, - "gpu_id": visual_gpu_ids[self.vit_rank_id], "world_size": self.vit_tp, "weight_dir": weight_dir, "data_type": self.data_type, From ecb73a38c129cb35c1390a1bc054a494901c0045 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 11 Dec 2024 15:59:58 +0800 Subject: [PATCH 04/17] custom pre_process --- lightllm/models/vit/__init__.py | 21 ++++++++++ lightllm/models/vit/layer_weights/__init__.py | 38 +++++++++++++++++++ lightllm/models/vit/model.py | 5 ++- .../visualserver/model_infer/model_rpc.py | 1 - 4 files changed, 62 insertions(+), 3 deletions(-) diff --git a/lightllm/models/vit/__init__.py b/lightllm/models/vit/__init__.py index 8fac3743c..cbe40d195 100644 --- a/lightllm/models/vit/__init__.py +++ b/lightllm/models/vit/__init__.py @@ -1,4 +1,25 @@ import os +import importlib.util from lightllm.utils.log_utils import init_logger +from lightllm.models.internvl.img_process import load_image as default_load_image logger = init_logger(__name__) + + +def get_load_image_func(weight_dir): + global load_image + pre_process_path = os.path.join(weight_dir, "pre_process.py") + if os.path.exists(pre_process_path): + logger.info(f"Found pre_process.py in {weight_dir}, attempting to load load_image from it.") + spec = importlib.util.spec_from_file_location("pre_process", pre_process_path) + pre_process = importlib.util.module_from_spec(spec) + spec.loader.exec_module(pre_process) + if hasattr(pre_process, "load_image"): + logger.info("load_image function replaced by the one in pre_process.py.") + return pre_process.load_image + else: + logger.info("load_image function not found in pre_process.py.") + else: + logger.info(f"pre_process.py not found in {weight_dir}, using default load_image.") + + return default_load_image diff --git a/lightllm/models/vit/layer_weights/__init__.py b/lightllm/models/vit/layer_weights/__init__.py index e69de29bb..08bcd04cf 100644 --- a/lightllm/models/vit/layer_weights/__init__.py +++ b/lightllm/models/vit/layer_weights/__init__.py @@ -0,0 +1,38 @@ +import os +import importlib.util + +# 默认的load_image函数 +def default_load_image(image_path): + print(f"Loading image using default function: {image_path}") + # 默认的加载图像逻辑(这里只是示例) + return image_path + + +# 用户提供的目录路径 +directory = "./user_directory" + +# 设定默认的load_image函数为default_load_image +load_image = default_load_image + +# 检查目录中是否有pre_process.py文件 +pre_process_path = os.path.join(directory, "pre_process.py") + +if os.path.exists(pre_process_path): + print(f"Found pre_process.py in {directory}, attempting to load load_image from it.") + + # 使用importlib来加载模块 + spec = importlib.util.spec_from_file_location("pre_process", pre_process_path) + pre_process = importlib.util.module_from_spec(spec) + spec.loader.exec_module(pre_process) + + # 如果pre_process.py中有load_image函数,则替换默认函数 + if hasattr(pre_process, "load_image"): + load_image = pre_process.load_image + print("load_image function replaced by the one in pre_process.py.") + else: + print("load_image function not found in pre_process.py.") +else: + print(f"pre_process.py not found in {directory}, using default load_image.") + +# 使用当前的load_image函数 +image = load_image("path/to/image.jpg") diff --git 
a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 56fdfda80..9ef011005 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -8,7 +8,7 @@ from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight from lightllm.models.vit.layer_weights.hf_load_utils import load_hf_weights from lightllm.utils.log_utils import init_logger -from lightllm.models.internvl.img_process import load_image +from lightllm.models.vit import get_load_image_func import torchvision.transforms as T from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from PIL import Image @@ -42,6 +42,7 @@ def __init__(self, kvargs): self.data_type = kvargs.get("data_type", "float16") self.quant_type = kvargs.get("quant_type", None) self.quant_cfg_path = kvargs.get("quant_cfg", None) + self.load_image_func = get_load_image_func(self.weight_dir_) self._init_datatype() self._init_config() @@ -148,7 +149,7 @@ def encode(self, image_uuids: List): uuids.append(url) image_data = read_shm(get_shm_name_data(url)) image_data = Image.open(BytesIO(image_data)) - t = load_image(image_data) + t = self.load_image_func(image_data) img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(url), url)) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 110e1abfe..a797aef71 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -60,7 +60,6 @@ def exposed_init_model(self, kvargs): "weight_dir": weight_dir, "data_type": self.data_type, } - print(kvargs) self.model = VisionTransformer(kvargs) else: raise Exception(f"can not support {self.model_type} now") From 826cc23cfa81dd5380ff0ff0d6d2793b397a5785 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 11 Dec 2024 16:13:24 +0800 Subject: [PATCH 05/17] vit quant & quant for dp + tp --- lightllm/common/quantization/ppl_quant.py | 6 ++++-- lightllm/common/quantization/torchao_quant.py | 4 +++- lightllm/common/quantization/vllm_quant.py | 12 +++++++----- .../router/model_infer/mode_backend/base_backend.py | 1 + lightllm/server/visualserver/manager.py | 2 ++ .../server/visualserver/model_infer/model_rpc.py | 2 ++ 6 files changed, 19 insertions(+), 8 deletions(-) diff --git a/lightllm/common/quantization/ppl_quant.py b/lightllm/common/quantization/ppl_quant.py index cb1d6a4ea..4015bda64 100644 --- a/lightllm/common/quantization/ppl_quant.py +++ b/lightllm/common/quantization/ppl_quant.py @@ -1,3 +1,4 @@ +import os import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS @@ -9,6 +10,7 @@ class PPLW4A16QuantizationMethod(QuantizationMethod): def __init__(self, group_size=128): super().__init__() self.group_size = group_size + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) def quantize(self, weight: torch.Tensor): """ @@ -17,7 +19,7 @@ def quantize(self, weight: torch.Tensor): qweight: [K, N//8] int32 (packed int4*8) new pack_order q_scale: [K//group_size, N] int32 """ - weight = weight.to(dtype=torch.float16).cuda() + weight = weight.to(dtype=torch.float16).cuda(self.device_id_) from lightllm_ppl_int4_kernel import int4_weight_encode qweight_new, q_scale = int4_weight_encode(weight, self.group_size) @@ -71,7 +73,7 @@ def quantize(self, weight: torch.Tensor): from flash_llm_fp6_llm import weight_quant_to_fp6 fp6_weight = weight_quant_to_fp6(quant_half, fp6_weight, True) - return 
fp6_weight.cuda(), scale.half().contiguous().cuda() + return fp6_weight.cuda(self.device_id_), scale.half().contiguous().cuda(self.device_id_) def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): """ """ diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index e2e93d1ea..812d8e58d 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -1,3 +1,4 @@ +import os import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS @@ -32,11 +33,12 @@ def __init__(self): assert HAS_TORCH_AO, "torchao is not installed, you can't use quant api of it" assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" self.quant_func = None + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) def quantize(self, weight: torch.Tensor): """ """ dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - dummy_linear.weight = torch.nn.Parameter(weight.cuda()) + dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) quantize_(dummy_linear, self.quant_func) return dummy_linear.weight diff --git a/lightllm/common/quantization/vllm_quant.py b/lightllm/common/quantization/vllm_quant.py index 1ba7bf4d9..e3df018c2 100644 --- a/lightllm/common/quantization/vllm_quant.py +++ b/lightllm/common/quantization/vllm_quant.py @@ -1,3 +1,4 @@ +import os import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS @@ -15,6 +16,7 @@ class vLLMBaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() assert HAS_VLLM, "vllm is not installed, you can't use quant api of it" + self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) def quantize(self, weight: torch.Tensor): """ """ @@ -32,12 +34,12 @@ def __init__(self): def quantize(self, weight: torch.Tensor): if hasattr(weight, "scale"): - return weight.data.transpose(0, 1).cuda(), weight.scale.cuda() + return weight.data.transpose(0, 1).cuda(self.device_id_), weight.scale.cuda(self.device_id_) weight = weight.float() scale = weight.abs().max(dim=-1)[0] / 127 weight = weight.transpose(0, 1) / scale.reshape(1, -1) weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) - return weight.cuda(), scale.cuda() + return weight.cuda(self.device_id_), scale.cuda(self.device_id_) def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True) @@ -61,7 +63,7 @@ def quantize(self, weight: torch.Tensor): if self.is_moe: return self.quantize_moe(weight) qweight, weight_scale = ops.scaled_fp8_quant( - weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=True + weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) return qweight.transpose(0, 1), weight_scale @@ -69,10 +71,10 @@ def quantize_moe(self, weight): num_experts = weight.shape[0] qweights = [] weight_scales = [] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda() + qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) for i in range(num_experts): qweight, weight_scale = ops.scaled_fp8_quant( - weight[i].contiguous().cuda(), scale=None, use_per_token_if_dynamic=False + weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False ) qweights[i] = qweight weight_scales.append(weight_scale) diff --git 
a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4bb1a1a2a..6a0adc6f9 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -77,6 +77,7 @@ def init_model(self, kvargs): self.pd_rpyc_port = kvargs.get("pd_rpyc_port", None) max_total_token_num = kvargs["max_total_token_num"] + os.environ["CURRENT_DEVICE_ID"] = str(self.tp_rank) torch.cuda.set_device(self.tp_rank) dist.init_process_group( diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index b03a27f6f..6ab3e70df 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -70,6 +70,8 @@ async def wait_to_model_ready(self): "data_type": self.args.data_type, "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], "visual_gpu_ids": self.args.visual_gpu_ids, + "quant_type": self.args.quant_type, + "quant_cfg": self.args.quant_cfg, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index a797aef71..f25767167 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -59,6 +59,8 @@ def exposed_init_model(self, kvargs): "world_size": self.vit_tp, "weight_dir": weight_dir, "data_type": self.data_type, + "quant_type": kvargs["quant_type"], + "quant_cfg": kvargs["quant_cfg"], } self.model = VisionTransformer(kvargs) else: From 3fbf429c335819ba101b2b6da44d25b9acb6221c Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 11 Dec 2024 16:32:49 +0800 Subject: [PATCH 06/17] add quant use_custom_alloc --- .../layer_weights/meta_weights/mm_weight.py | 4 +++- lightllm/common/quantization/ppl_quant.py | 9 +++++--- .../common/quantization/quantize_method.py | 2 +- lightllm/common/quantization/torchao_quant.py | 2 +- lightllm/common/quantization/vllm_quant.py | 22 ++++++++++++------- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py index 0d91fdc95..a0f60823a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py @@ -28,7 +28,9 @@ def set_quant_method(self, quant_method): def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True): if self.quant_method is not None: - return self.quant_method.apply(input_tensor, self.weight, self.bias, out) + return self.quant_method.apply( + input_tensor, self.weight, self.bias, out, use_custom_tensor_mananger=use_custom_tensor_mananger + ) if out is None: shape = (input_tensor.shape[0], self.weight.shape[1]) dtype = input_tensor.dtype diff --git a/lightllm/common/quantization/ppl_quant.py b/lightllm/common/quantization/ppl_quant.py index 4015bda64..e87f2b1e2 100644 --- a/lightllm/common/quantization/ppl_quant.py +++ b/lightllm/common/quantization/ppl_quant.py @@ -25,7 +25,7 @@ def quantize(self, weight: torch.Tensor): qweight_new, q_scale = int4_weight_encode(weight, self.group_size) return qweight_new, q_scale - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, 
workspace=None, use_custom_tensor_mananger=True): """ input_tensor is activation: (M, K) float16 weights: [qweight, scale_weight] @@ -40,7 +40,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): shape = (input_tensor.shape[0], qweight.shape[0] * 8) dtype = input_tensor.dtype device = input_tensor.device - out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + else: + out = torch.empty(shape, dtype=dtype, device=device) from lightllm_ppl_int4_kernel import matmul_i4_fp16 from lightllm_ppl_int4_kernel import int4_weight_decode @@ -75,7 +78,7 @@ def quantize(self, weight: torch.Tensor): fp6_weight = weight_quant_to_fp6(quant_half, fp6_weight, True) return fp6_weight.cuda(self.device_id_), scale.half().contiguous().cuda(self.device_id_) - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): """ """ from flash_llm_fp6_llm import linear_forward_cuda diff --git a/lightllm/common/quantization/quantize_method.py index e50938f5a..e007e0e5f 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -11,5 +11,5 @@ def quantize(self, weights: torch.Tensor): pass @abstractmethod - def apply(self, input_tensor, weight, bias=None, out=None): + def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True): pass diff --git a/lightllm/common/quantization/torchao_quant.py index 812d8e58d..e35b99b44 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -42,7 +42,7 @@ def quantize(self, weight: torch.Tensor): quantize_(dummy_linear, self.quant_func) return dummy_linear.weight - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, use_custom_tensor_mananger=True): return F.linear(input_tensor, weights, bias) diff --git a/lightllm/common/quantization/vllm_quant.py index 6c0b23745..82b7790d1 100644 --- a/lightllm/common/quantization/vllm_quant.py +++ b/lightllm/common/quantization/vllm_quant.py @@ -41,7 +41,7 @@ def quantize(self, weight: torch.Tensor): weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) return weight.cuda(self.device_id_), scale.cuda(self.device_id_) - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): input_scale = None if len(weights) == 3: qweight, weight_scale, input_scale = weights @@ -54,9 +54,12 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): m = input_tensor.shape[0] n = qweight.shape[1] if out is None: - out = g_cache_manager.alloc_tensor( - (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False - ) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor( + (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False + ) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, 
weight_scale, bias) return out @@ -89,13 +92,16 @@ def quantize_moe(self, weight): weight_scale = torch.cat(weight_scales, dim=0).reshape(-1) return qweights, weight_scale - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) m = input_tensor.shape[0] n = weights[0].shape[1] if out is None: - out = g_cache_manager.alloc_tensor( - (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False - ) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor( + (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False + ) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias) return out From 2c7c0ff83af9358e05d38ff066677a1e84034f91 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 11 Dec 2024 18:54:23 +0800 Subject: [PATCH 07/17] fix precision(ongoing) --- .../vit/layer_infer/transformer_layer_infer.py | 15 +++++++++++---- .../vit/layer_weights/transformer_layer_weight.py | 1 + lightllm/models/vit/model.py | 1 + .../server/visualserver/model_infer/model_rpc.py | 1 + 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index fe030f321..a4a7ac4da 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -31,10 +31,14 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): def norm(self, input, weight): input_dtype = input.dtype + input_shape = input.shape + input = input.view(-1, self.tp_padding_head_num * self.head_dim_) input = input.to(torch.float32) variance = input.pow(2).mean(-1, keepdim=True) input = input * torch.rsqrt(variance + self.eps_) - return weight * input.to(input_dtype) + out = weight * input.to(input_dtype) + out = out.reshape(input_shape) + return out def tp_norm(self, input, weight): input_shape = input.shape @@ -42,7 +46,7 @@ def tp_norm(self, input, weight): input_dtype = input.dtype input = input.to(torch.float32) tp_variance = input.pow(2).sum(-1, keepdim=True) - dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False) + # dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False) variance = tp_variance / self.embed_dim_ input = input * torch.rsqrt(variance + self.eps_) out = weight * input.to(input_dtype) @@ -75,8 +79,8 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten ) def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: - q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight) - k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight) + q_norm = self.norm(q, layer_weight.q_norm_weight_.weight) + k_norm = self.norm(k, layer_weight.k_norm_weight_.weight) return q_norm, k_norm def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: @@ -85,6 +89,9 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tens qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False) qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_) q, k, v = qkv.unbind(2) + q = q.contiguous() + 
k = k.contiguous() + v = v.contiguous() return q, k, v def _context_attention_kernel(self, q, k, v) -> torch.Tensor: diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 2676ff3ed..199ee8d19 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -163,5 +163,6 @@ def load_hf_weights(self, weights): ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] self.ls2 = self._cuda(ls2) self.use_ls = True + print(self.ls1) return super().load_hf_weights(weights) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 9ef011005..025809c95 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -163,6 +163,7 @@ def encode(self, image_uuids: List): imgs = torch.cat(img_tensors, dim=0) pixel_values = imgs.cuda().to(dtype=self.data_type) + print(pixel_values.shape, pixel_values.dtype) all_img_embeds = self.forward(pixel_values) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index f25767167..1bce4b926 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -63,6 +63,7 @@ def exposed_init_model(self, kvargs): "quant_cfg": kvargs["quant_cfg"], } self.model = VisionTransformer(kvargs) + # self.model = InternVLVisionModel() else: raise Exception(f"can not support {self.model_type} now") From 71a9e560893fa7d4e905374c40b36fc1d0c30141 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 11 Dec 2024 23:01:11 +0800 Subject: [PATCH 08/17] fix vit6b precision --- .../layer_infer/transformer_layer_infer.py | 10 ++--- .../layer_weights/transformer_layer_weight.py | 1 - .../vit/triton_kernel/flashattention_nopad.py | 40 ++++++++++++------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index a4a7ac4da..f13fcf428 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -46,7 +46,8 @@ def tp_norm(self, input, weight): input_dtype = input.dtype input = input.to(torch.float32) tp_variance = input.pow(2).sum(-1, keepdim=True) - # dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False) + if self.world_size_ > 1: + dist.all_reduce(tp_variance, op=dist.ReduceOp.SUM, async_op=False) variance = tp_variance / self.embed_dim_ input = input * torch.rsqrt(variance + self.eps_) out = weight * input.to(input_dtype) @@ -79,8 +80,8 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten ) def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: - q_norm = self.norm(q, layer_weight.q_norm_weight_.weight) - k_norm = self.norm(k, layer_weight.k_norm_weight_.weight) + q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight) + k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight) return q_norm, k_norm def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: @@ -89,9 +90,6 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tens qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False) qkv = qkv.view(batch_size, seq_len, 3, -1, 
self.head_dim_) q, k, v = qkv.unbind(2) - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() return q, k, v def _context_attention_kernel(self, q, k, v) -> torch.Tensor: diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 199ee8d19..2676ff3ed 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -163,6 +163,5 @@ def load_hf_weights(self, weights): ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] self.ls2 = self._cuda(ls2) self.use_ls = True - print(self.ls1) return super().load_hf_weights(weights) diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 4a05ff635..062ff51f5 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -22,6 +22,14 @@ def _fwd_kernel( q_stride_s, q_stride_h, q_stride_d, + k_stride_b, + k_stride_s, + k_stride_h, + k_stride_d, + v_stride_b, + v_stride_s, + v_stride_h, + v_stride_d, o_stride_b, o_stride_s, o_stride_h, @@ -30,9 +38,9 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): - cur_batch = tl.program_id(0) + cur_batch = tl.program_id(2) cur_head = tl.program_id(1) - start_m = tl.program_id(2) + start_m = tl.program_id(0) # initialize offsets offs_n = tl.arange(0, BLOCK_N) @@ -49,9 +57,9 @@ def _fwd_kernel( start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- off_k = ( - cur_batch * q_stride_b - + (start_n + offs_n[None, :]) * q_stride_s - + cur_head * q_stride_h + cur_batch * k_stride_b + + (start_n + offs_n[None, :]) * k_stride_s + + cur_head * k_stride_h + offs_d[:, None] ) k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < seq_len, other=0.0) @@ -71,9 +79,9 @@ def _fwd_kernel( # update acc off_v = ( - cur_batch * q_stride_b - + (start_n + offs_n[:, None]) * q_stride_s - + cur_head * q_stride_h + cur_batch * v_stride_b + + (start_n + offs_n[:, None]) * v_stride_s + + cur_head * v_stride_h + offs_d[None, :] ) v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < seq_len, other=0.0) @@ -104,8 +112,8 @@ def flash_attention_fwd( batch_size, seq_len, head_num, head_dim = q.shape sm_scale = 1.0 / (head_dim ** 0.5) # 计算scale系数 - grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK)) # batch, head, - # grid = (triton.cdiv(seq_len, BLOCK), batch_size, head_num) # batch, head, + # grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK)) # batch, head, + grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size) # batch, head, num_warps = 4 _fwd_kernel[grid]( q, @@ -118,6 +126,14 @@ def flash_attention_fwd( q.stride(1), q.stride(2), q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), o.stride(0), o.stride(1), o.stride(2), @@ -157,7 +173,6 @@ def test(): k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - torch_out = torch_att(q, k, v) import time @@ -174,6 +189,3 @@ def test(): print("max ", torch.max(torch.abs(torch_out - o))) print("mean ", torch.mean(torch.abs(torch_out - o))) assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -# test() From 
9fb8f4edfeda812dca0680afd4543326d00cd062 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 11 Dec 2024 23:25:03 +0800 Subject: [PATCH 09/17] add statc speed script for vit --- test/model/model_infer_vit.py | 72 +++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 test/model/model_infer_vit.py diff --git a/test/model/model_infer_vit.py b/test/model/model_infer_vit.py new file mode 100644 index 000000000..65d0a7915 --- /dev/null +++ b/test/model/model_infer_vit.py @@ -0,0 +1,72 @@ +import numpy as np +from multiprocessing import Queue +import multiprocessing +import os +import time + +from lightllm.models.vit.model import VisionTransformer + + +def test_model_inference(world_size, weight_dir, quant_type=None): + workers = [] + for rank_id in range(world_size): + kvargs = { + "tp_rank": rank_id, + "world_size": world_size, + "weight_dir": weight_dir, + "data_type": "bf16", + "quant_type": quant_type, + "quant_cfg": None, + } + + proc = multiprocessing.Process(target=tppart_model_infer, args=(kvargs,)) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + return + + +def tppart_model_infer(model_kvargs): + import torch + from lightllm.distributed import set_custom_reduce + import torch.distributed as dist + + rank_id = model_kvargs["tp_rank"] + world_size = model_kvargs["world_size"] + + torch.cuda.set_device(rank_id) + os.environ["CURRENT_DEVICE_ID"] = str(rank_id) + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size) + set_custom_reduce() + dist.barrier() + torch.cuda.empty_cache() + model_part = VisionTransformer(model_kvargs) + test_data = torch.randn((13, 3, 448, 448)).cuda().to(torch.bfloat16) + # warm up + torch.cuda.synchronize() + for i in range(10): + model_part.forward(test_data) + torch.cuda.synchronize() + + torch.cuda.synchronize() + start_time = time.time() + for i in range(50): + model_part.forward(test_data) + torch.cuda.synchronize() + end_time = time.time() + + if rank_id == 0: + print("time total cost(ms):", (end_time - start_time) / 50 * 1000) + + return + + +if __name__ == "__main__": + import torch + + world_size = 2 + weight_dir = "your_multimodal_vit_path" + torch.multiprocessing.set_start_method("spawn") + test_model_inference(world_size, weight_dir, "vllm-w8a8") From 0d71d30df18b7c46ffe5c63c149aa5dc27d3565d Mon Sep 17 00:00:00 2001 From: shihaobai Date: Mon, 23 Dec 2024 12:55:34 +0800 Subject: [PATCH 10/17] --vit_quant_type --- lightllm/server/api_cli.py | 15 +++++++++++++++ lightllm/server/visualserver/manager.py | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 28e066425..544f2e8cc 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -244,6 +244,21 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Path of quantization config. It can be used for mixed quantization. Examples can be found in lightllm/common/quantization/configs.""", ) + parser.add_argument( + "--vit_quant_type", + type=str, + default=None, + help="""Quantization method: ppl-w4a16-128 | flashllm-w6a16 + | ao-int4wo-[32,64,128,256] | ao-int8wo | ao-fp8w8a16 | ao-fp6w6a16 + | vllm-w8a8 | vllm-fp8w8a8""", + ) + parser.add_argument( + "--vit_quant_cfg", + type=str, + default=None, + help="""Path of quantization config. It can be used for mixed quantization. 
+ Examples can be found in lightllm/common/quantization/configs.""", + ) parser.add_argument( "--static_quant", action="store_true", diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 6ab3e70df..a74b25be0 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -70,8 +70,8 @@ async def wait_to_model_ready(self): "data_type": self.args.data_type, "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], "visual_gpu_ids": self.args.visual_gpu_ids, - "quant_type": self.args.quant_type, - "quant_cfg": self.args.quant_cfg, + "quant_type": self.args.vit_quant_type, + "quant_cfg": self.args.vit_quant_cfg, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) From 2f45df5499d6efaa1b2fce5dd433bb9dd6ff22d4 Mon Sep 17 00:00:00 2001 From: shihaobai Date: Mon, 23 Dec 2024 14:03:49 +0800 Subject: [PATCH 11/17] get custom get_image_path --- lightllm/models/internvl/model.py | 9 ++++++--- lightllm/models/vit/__init__.py | 20 ++++++++++++++++++++ lightllm/server/httpserver/manager.py | 2 +- lightllm/server/tokenizer.py | 13 +++++++------ 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index 8aa127e83..02abf9a9f 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -13,7 +13,9 @@ ) from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight from lightllm.models.llava.llava_visual import LlavaVisionModel -from lightllm.models.internvl.img_process import get_image_patch + +# from lightllm.models.internvl.img_process import get_image_patch +from lightllm.models.vit import get_image_patch_func from typing import Dict import lightllm.models.internvl.internvl_visual import torch @@ -36,9 +38,10 @@ def __init__(self, tokenizer, model_cfg, **kwargs): self.image_end_tag = IMG_END_TOKEN self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) + self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"]) def get_image_token_length(self, img: ImageItem): - return get_image_patch(img.image_w, img.image_h, use_thumbnail=True) * self.image_length + return self.get_image_patch(img.image_w, img.image_h, use_thumbnail=True) * self.image_length # only change the impl of the encode func: def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): @@ -47,7 +50,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): image_count = len(multimodal_params.images) prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) - origin_ids = self.tokenizer.encode(prompt) + origin_ids = self.tokenizer.encode(prompt, kwargs["add_special_tokens"]) # --> id,id+1...id+num input_ids = [] image_id = 0 diff --git a/lightllm/models/vit/__init__.py b/lightllm/models/vit/__init__.py index cbe40d195..eb10b56bc 100644 --- a/lightllm/models/vit/__init__.py +++ b/lightllm/models/vit/__init__.py @@ -2,6 +2,7 @@ import importlib.util from lightllm.utils.log_utils import init_logger from lightllm.models.internvl.img_process import load_image as default_load_image +from lightllm.models.internvl.img_process import get_image_patch as default_get_image_patch logger = init_logger(__name__) @@ -23,3 +24,22 @@ def get_load_image_func(weight_dir): logger.info(f"pre_process.py not found in {weight_dir}, using default load_image.") return 
default_load_image + + +def get_image_patch_func(weight_dir): + global load_image + pre_process_path = os.path.join(weight_dir, "pre_process.py") + if os.path.exists(pre_process_path): + logger.info(f"Found pre_process.py in {weight_dir}, attempting to load get_image_patch from it.") + spec = importlib.util.spec_from_file_location("pre_process", pre_process_path) + pre_process = importlib.util.module_from_spec(spec) + spec.loader.exec_module(pre_process) + if hasattr(pre_process, "get_image_patch"): + logger.info("load_image function replaced by the one in pre_process.py.") + return pre_process.get_image_patch + else: + logger.info("get_image_patch function not found in pre_process.py.") + else: + logger.info(f"pre_process.py not found in {weight_dir}, using default get_image_patch.") + + return default_get_image_patch diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index ae7d6f1e6..7bde5724f 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -206,7 +206,7 @@ async def _encode( if self.enable_multimodal: assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!" await self._alloc_multimodal_resources(multimodal_params) - prompt_ids = self.tokenizer.encode(prompt, multimodal_params) + prompt_ids = self.tokenizer.encode(prompt, multimodal_params, add_special_tokens=add_special_tokens) else: prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) return prompt_ids diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e2d4452fb..35c4e846d 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -67,6 +67,12 @@ def get_tokenizer( kwargs["use_fast"] = False tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs) + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.info( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead." + ) + model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) model_type = model_cfg.get("model_type", "") if model_type == "llava" or model_type == "internlmxcomposer2": @@ -79,11 +85,6 @@ def get_tokenizer( image_processor = AutoProcessor.from_pretrained(tokenizer_name) tokenizer = QWen2VLTokenizer(tokenizer=tokenizer, image_processor=image_processor, model_cfg=model_cfg) elif model_type == "internvl_chat": - tokenizer = InternvlTokenizer(tokenizer, model_cfg) + tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) - if not isinstance(tokenizer, PreTrainedTokenizerFast): - logger.info( - "Using a slow tokenizer. This might cause a significant " - "slowdown. Consider using a fast tokenizer instead." 
- ) return tokenizer From e96befb28e931a054010385f0603b0b776f0989b Mon Sep 17 00:00:00 2001 From: shihaobai Date: Mon, 23 Dec 2024 14:19:56 +0800 Subject: [PATCH 12/17] fix internvl tokenizer --- lightllm/models/internvl/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index 02abf9a9f..d62e20bdf 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -50,7 +50,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): image_count = len(multimodal_params.images) prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) - origin_ids = self.tokenizer.encode(prompt, kwargs["add_special_tokens"]) + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) # --> id,id+1...id+num input_ids = [] image_id = 0 From 83dc433440c39793569de284c5ea545591c76566 Mon Sep 17 00:00:00 2001 From: shihaobai Date: Mon, 23 Dec 2024 14:26:50 +0800 Subject: [PATCH 13/17] fix 0.3b --- lightllm/models/vit/layer_weights/transformer_layer_weight.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 2676ff3ed..dbe78fda5 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -149,10 +149,10 @@ def load_hf_weights(self, weights): weights[self._v_bias_name] = v_bias_ del weights[f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.bias"] - if self._q_norm_weight_name in weights: + if self.qk_norm and self._q_norm_weight_name in weights: weights[self._q_norm_weight_name] = F.pad(weights[self._q_norm_weight_name], (0, self.padding_hidden_size)) - if self._k_norm_weight_name in weights: + if self.qk_norm and self._k_norm_weight_name in weights: weights[self._k_norm_weight_name] = F.pad(weights[self._k_norm_weight_name], (0, self.padding_hidden_size)) if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: From cd0770727ef3dd73807462ff213ea304af9204d2 Mon Sep 17 00:00:00 2001 From: shihaobai Date: Tue, 24 Dec 2024 19:54:15 +0800 Subject: [PATCH 14/17] fix vit --- lightllm/models/internvl/internvl_visual.py | 20 +++++++--- lightllm/models/internvl/model.py | 2 +- lightllm/models/vit/layer_weights/__init__.py | 38 ------------------- 3 files changed, 15 insertions(+), 45 deletions(-) diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py index 4107a0b46..733ffb711 100644 --- a/lightllm/models/internvl/internvl_visual.py +++ b/lightllm/models/internvl/internvl_visual.py @@ -11,6 +11,7 @@ from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO from lightllm.models.internvl.img_process import load_image +from lightllm.models.vit import get_load_image_func from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -24,13 +25,20 @@ def load_model(self, weight_dir): assert torch.cuda.is_available() self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 self.config = json.load(open(os.path.join(weight_dir, "config.json"))) - self.model = AutoModel.from_pretrained( - weight_dir, - torch_dtype=self.dtype, - trust_remote_code=True, - language_model="fake_language_model", + # self.model = AutoModel.from_pretrained( + # weight_dir, + # torch_dtype=self.dtype, + 
# trust_remote_code=True, + # language_model="fake_language_model", + # ) + from internvl_chat import InternVLChatModel, InternVLChatConfig + + cfg = InternVLChatConfig.from_pretrained(weight_dir) + self.model = InternVLChatModel.from_pretrained( + weight_dir, config=cfg, torch_dtype=self.dtype, language_model="fake_language_model" ) self.model.eval().cuda() + self.load_image_func = get_load_image_func(weight_dir) def cuda(self): return self @@ -46,7 +54,7 @@ def encode(self, image_uuids: List): uuids.append(url) image_data = read_shm(get_shm_name_data(url)) image_data = Image.open(BytesIO(image_data)) - t = load_image(image_data) + t = self.load_image_func(image_data) img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(url), url)) diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index d62e20bdf..d32769886 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -41,7 +41,7 @@ def __init__(self, tokenizer, model_cfg, **kwargs): self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"]) def get_image_token_length(self, img: ImageItem): - return self.get_image_patch(img.image_w, img.image_h, use_thumbnail=True) * self.image_length + return self.get_image_patch_func(img.image_w, img.image_h, use_thumbnail=True) * self.image_length # only change the impl of the encode func: def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): diff --git a/lightllm/models/vit/layer_weights/__init__.py b/lightllm/models/vit/layer_weights/__init__.py index 08bcd04cf..e69de29bb 100644 --- a/lightllm/models/vit/layer_weights/__init__.py +++ b/lightllm/models/vit/layer_weights/__init__.py @@ -1,38 +0,0 @@ -import os -import importlib.util - -# 默认的load_image函数 -def default_load_image(image_path): - print(f"Loading image using default function: {image_path}") - # 默认的加载图像逻辑(这里只是示例) - return image_path - - -# 用户提供的目录路径 -directory = "./user_directory" - -# 设定默认的load_image函数为default_load_image -load_image = default_load_image - -# 检查目录中是否有pre_process.py文件 -pre_process_path = os.path.join(directory, "pre_process.py") - -if os.path.exists(pre_process_path): - print(f"Found pre_process.py in {directory}, attempting to load load_image from it.") - - # 使用importlib来加载模块 - spec = importlib.util.spec_from_file_location("pre_process", pre_process_path) - pre_process = importlib.util.module_from_spec(spec) - spec.loader.exec_module(pre_process) - - # 如果pre_process.py中有load_image函数,则替换默认函数 - if hasattr(pre_process, "load_image"): - load_image = pre_process.load_image - print("load_image function replaced by the one in pre_process.py.") - else: - print("load_image function not found in pre_process.py.") -else: - print(f"pre_process.py not found in {directory}, using default load_image.") - -# 使用当前的load_image函数 -image = load_image("path/to/image.jpg") From d4e4345ad7df3d0f945b42e3be0ab0688f7e6f0f Mon Sep 17 00:00:00 2001 From: shihaobai Date: Wed, 25 Dec 2024 14:12:20 +0800 Subject: [PATCH 15/17] add device utils --- .../layer_weights/meta_weights/base_weight.py | 3 ++- .../meta_weights/fused_moe_weight.py | 8 +++----- .../layer_weights/meta_weights/norm_weight.py | 5 +++-- lightllm/common/quantization/ppl_quant.py | 3 ++- lightllm/common/quantization/torchao_quant.py | 3 ++- lightllm/common/quantization/vllm_quant.py | 3 ++- .../pre_and_post_layer_weight.py | 5 +++-- .../layer_weights/transformer_layer_weight.py | 6 +++--- .../model_infer/mode_backend/base_backend.py | 1 - 
.../visualserver/model_infer/model_rpc.py | 4 +++- lightllm/utils/device_utils.py | 19 +++++++++++++++++++ 11 files changed, 42 insertions(+), 18 deletions(-) create mode 100644 lightllm/utils/device_utils.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index dcfbb247c..810a633a3 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -2,6 +2,7 @@ import torch from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_world_size, get_rank +from lightllm.utils.device_utils import get_current_device_id class BaseWeight(ABC): @@ -21,7 +22,7 @@ class BaseWeightTpl(BaseWeight): def __init__(self): self.world_size_ = get_world_size() self.tp_rank_ = get_rank() - self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_)) + self.device_id_ = get_current_device_id() def load_hf_weights(self, weights): pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py index 6425f0ae0..4c5b849fc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py @@ -7,6 +7,7 @@ from lightllm.common.vllm_kernel import _custom_ops as ops from lightllm.common.fused_moe import fused_experts +from lightllm.utils.device_utils import get_current_device_id class FusedMoeWeight(BaseWeight): @@ -22,7 +23,6 @@ def __init__( self.split_inter_size = split_inter_size self.data_type_ = data_type self.tp_rank_ = get_rank() - self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID", self.tp_rank_)) self.experts_up_projs = [None] * self.n_routed_experts self.experts_gate_projs = [None] * self.n_routed_experts self.expert_gate_up_proj_etp = None @@ -179,10 +179,8 @@ def load_hf_weights(self, weights): self._fuse() def _cuda(self, cpu_tensor): - if self.device_id_ is None: - return cpu_tensor.contiguous().to(self.data_type_).cuda() - else: - return cpu_tensor.contiguous().to(self.data_type_).cuda(self.device_id_) + device_id = get_current_device_id() + return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) def verify_load(self): if os.environ.get("ETP_MODE_ENABLED") == "true": diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 5ba48516c..8d593b603 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -1,3 +1,4 @@ +import torch from .base_weight import BaseWeightTpl @@ -41,8 +42,8 @@ def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): self.split_n_embed = split_n_embed def load_hf_weights(self, weights): - start = self.split_n_embed * self.device_id_ - end = self.split_n_embed * (self.device_id_ + 1) + start = self.split_n_embed * self.tp_rank_ + end = self.split_n_embed * (self.tp_rank_ + 1) if self.weight_name in weights: self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_) diff --git a/lightllm/common/quantization/ppl_quant.py b/lightllm/common/quantization/ppl_quant.py index e87f2b1e2..0dc29fe25 100644 --- a/lightllm/common/quantization/ppl_quant.py +++ b/lightllm/common/quantization/ppl_quant.py @@ -3,6 +3,7 @@ from 
.quantize_method import QuantizationMethod from .registry import QUANTMETHODS from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.utils.device_utils import get_current_device_id @QUANTMETHODS.register("ppl-w4a16-128") @@ -10,7 +11,7 @@ class PPLW4A16QuantizationMethod(QuantizationMethod): def __init__(self, group_size=128): super().__init__() self.group_size = group_size - self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) + self.device_id_ = get_current_device_id() def quantize(self, weight: torch.Tensor): """ diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index e35b99b44..79f3c1391 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -3,6 +3,7 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.utils.device_utils import get_current_device_id import torch.nn.functional as F try: @@ -33,7 +34,7 @@ def __init__(self): assert HAS_TORCH_AO, "torchao is not installed, you can't use quant api of it" assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" self.quant_func = None - self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) + self.device_id_ = get_current_device_id() def quantize(self, weight: torch.Tensor): """ """ diff --git a/lightllm/common/quantization/vllm_quant.py b/lightllm/common/quantization/vllm_quant.py index 82b7790d1..eb42b9366 100644 --- a/lightllm/common/quantization/vllm_quant.py +++ b/lightllm/common/quantization/vllm_quant.py @@ -3,6 +3,7 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.utils.device_utils import get_current_device_id import torch.nn.functional as F try: @@ -16,7 +17,7 @@ class vLLMBaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() assert HAS_VLLM, "vllm is not installed, you can't use quant api of it" - self.device_id_ = int(os.getenv("CURRENT_DEVICE_ID")) + self.device_id_ = get_current_device_id() def quantize(self, weight: torch.Tensor): """ """ diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index f56a884a6..56947055f 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -3,6 +3,7 @@ import numpy as np import torch.nn.functional as F from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.utils.device_utils import get_current_device_id class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): @@ -12,11 +13,11 @@ def __init__(self, tp_rank, world_size, data_type, network_config, mode): self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] - self.gpu_id_ = int(os.getenv("CURRENT_DEVICE_ID", tp_rank)) return def _cuda(self, cpu_tensor): - return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) + device_id = get_current_device_id() + return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) def _get_pos_embed(self, H, W): pos_embed = self.position_embedding[:, 1:, :] diff --git 
a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index dbe78fda5..7f7a9a7bb 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -11,17 +11,17 @@ MultiROWMMWeight, TpNormWeight, ) +from lightllm.utils.device_utils import get_current_device_id class ViTTransformerLayerWeight(TransformerLayerWeight): def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None): super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg) - self.gpu_id_ = int(os.getenv("CURRENT_DEVICE_ID", tp_rank)) - return def _cuda(self, cpu_tensor): - return cpu_tensor.contiguous().to(self.data_type_).cuda(self.gpu_id_) + device_id = get_current_device_id() + return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) def _parse_config(self): self.padding_hidden_size = self.network_config_["padding_hidden_size"] diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 31c0d563a..180a873e2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -77,7 +77,6 @@ def init_model(self, kvargs): self.pd_rpyc_port = kvargs.get("pd_rpyc_port", None) max_total_token_num = kvargs["max_total_token_num"] - os.environ["CURRENT_DEVICE_ID"] = str(self.tp_rank) if self.dp_size > 1: assert self.dp_size == self.world_size, "Currently only self-sustaining dp_size == tp_size" os.environ["ENABLE_DP"] = "1" diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 1bce4b926..e935503ae 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -16,6 +16,7 @@ from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.utils.device_utils import set_current_device_id class VisualModelRpcServer(rpyc.Service): @@ -36,6 +37,7 @@ def exposed_init_model(self, kvargs): self.data_type = kvargs["data_type"] torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) + print(visual_gpu_ids[self.vit_rank_id]) dist.init_process_group( backend="nccl", init_method=f"tcp://127.0.0.1:{visual_nccl_port}", @@ -43,7 +45,7 @@ def exposed_init_model(self, kvargs): world_size=self.vit_tp, ) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) - os.environ["CURRENT_DEVICE_ID"] = str(visual_gpu_ids[self.vit_rank_id]) + set_current_device_id(visual_gpu_ids[self.vit_rank_id]) try: self.model_type = model_cfg["model_type"] diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py new file mode 100644 index 000000000..aec6399f1 --- /dev/null +++ b/lightllm/utils/device_utils.py @@ -0,0 +1,19 @@ +import os +from functools import lru_cache + + +@lru_cache(maxsize=None) +def set_current_device_id(device_id: int): + os.environ["CURRENT_DEVICE_ID"] = str(device_id) + + +@lru_cache(maxsize=None) +def get_current_device_id(): + import torch + + if torch.cuda.is_available(): + default_device_id = torch.cuda.current_device() + device_id = os.getenv("CURRENT_DEVICE_ID", default_device_id) + 
return int(device_id) + else: + raise RuntimeError("Torch CUDA is not avaliable.") From 0e6d6fbd5e9bc4a5162ac55a81b480817caf915a Mon Sep 17 00:00:00 2001 From: shihaobai Date: Wed, 25 Dec 2024 14:36:42 +0800 Subject: [PATCH 16/17] remove print --- lightllm/server/visualserver/model_infer/model_rpc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index e935503ae..46e52d693 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -37,7 +37,6 @@ def exposed_init_model(self, kvargs): self.data_type = kvargs["data_type"] torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) - print(visual_gpu_ids[self.vit_rank_id]) dist.init_process_group( backend="nccl", init_method=f"tcp://127.0.0.1:{visual_nccl_port}", From 260a9b61d6ec479ed87e7ba1dbbde316c38a201d Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Wed, 25 Dec 2024 16:31:41 +0800 Subject: [PATCH 17/17] fix --- .../basemodel/layer_weights/meta_weights/base_weight.py | 1 - lightllm/common/quantization/ppl_quant.py | 2 -- lightllm/common/quantization/quantize_method.py | 3 ++- lightllm/common/quantization/torchao_quant.py | 2 -- lightllm/common/quantization/vllm_quant.py | 2 -- .../server/router/model_infer/mode_backend/base_backend.py | 2 ++ lightllm/server/visualserver/model_infer/model_rpc.py | 3 ++- lightllm/utils/device_utils.py | 6 +++--- 8 files changed, 9 insertions(+), 12 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 810a633a3..762617b71 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -1,4 +1,3 @@ -import os import torch from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_world_size, get_rank diff --git a/lightllm/common/quantization/ppl_quant.py b/lightllm/common/quantization/ppl_quant.py index 0dc29fe25..644c2174a 100644 --- a/lightllm/common/quantization/ppl_quant.py +++ b/lightllm/common/quantization/ppl_quant.py @@ -3,7 +3,6 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.utils.device_utils import get_current_device_id @QUANTMETHODS.register("ppl-w4a16-128") @@ -11,7 +10,6 @@ class PPLW4A16QuantizationMethod(QuantizationMethod): def __init__(self, group_size=128): super().__init__() self.group_size = group_size - self.device_id_ = get_current_device_id() def quantize(self, weight: torch.Tensor): """ diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index e007e0e5f..01d339674 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,10 +1,11 @@ import torch from abc import ABC, abstractmethod - +from lightllm.utils.device_utils import get_current_device_id class QuantizationMethod(ABC): def __init__(self): super().__init__() + self.device_id_ = get_current_device_id() @abstractmethod def quantize(self, weights: torch.Tensor): diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index 79f3c1391..b26e82352 100644 --- 
a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -3,7 +3,6 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.utils.device_utils import get_current_device_id import torch.nn.functional as F try: @@ -34,7 +33,6 @@ def __init__(self): assert HAS_TORCH_AO, "torchao is not installed, you can't use quant api of it" assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" self.quant_func = None - self.device_id_ = get_current_device_id() def quantize(self, weight: torch.Tensor): """ """ diff --git a/lightllm/common/quantization/vllm_quant.py b/lightllm/common/quantization/vllm_quant.py index eb42b9366..74685a3b9 100644 --- a/lightllm/common/quantization/vllm_quant.py +++ b/lightllm/common/quantization/vllm_quant.py @@ -3,7 +3,6 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.utils.device_utils import get_current_device_id import torch.nn.functional as F try: @@ -17,7 +16,6 @@ class vLLMBaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() assert HAS_VLLM, "vllm is not installed, you can't use quant api of it" - self.device_id_ = get_current_device_id() def quantize(self, weight: torch.Tensor): """ """ diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 180a873e2..44dddb9cc 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -39,6 +39,7 @@ from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock +from lightllm.utils.device_utils import set_current_device_id import torch.distributed as dist @@ -82,6 +83,7 @@ def init_model(self, kvargs): os.environ["ENABLE_DP"] = "1" torch.cuda.set_device(self.tp_rank) + set_current_device_id(self.tp_rank) dist.init_process_group( "nccl", diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 46e52d693..27adf4640 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -37,6 +37,7 @@ def exposed_init_model(self, kvargs): self.data_type = kvargs["data_type"] torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) + set_current_device_id(visual_gpu_ids[self.vit_rank_id]) dist.init_process_group( backend="nccl", init_method=f"tcp://127.0.0.1:{visual_nccl_port}", @@ -44,7 +45,7 @@ def exposed_init_model(self, kvargs): world_size=self.vit_tp, ) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) - set_current_device_id(visual_gpu_ids[self.vit_rank_id]) + try: self.model_type = model_cfg["model_type"] diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index aec6399f1..43ede1000 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -2,7 +2,6 @@ from functools import lru_cache -@lru_cache(maxsize=None) def set_current_device_id(device_id: int): os.environ["CURRENT_DEVICE_ID"] = str(device_id) @@ -12,8 +11,9 @@ def 
get_current_device_id():
     import torch
 
     if torch.cuda.is_available():
-        default_device_id = torch.cuda.current_device()
-        device_id = os.getenv("CURRENT_DEVICE_ID", default_device_id)
+        device_id = os.getenv("CURRENT_DEVICE_ID", None)
+        if device_id is None:
+            raise RuntimeError("set_current_device_id must be called first to set the current device")
         return int(device_id)
     else:
         raise RuntimeError("Torch CUDA is not avaliable.")
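
The device-utils patches above follow one contract: the process that owns a GPU calls torch.cuda.set_device() and then set_current_device_id() exactly once during init, and every later consumer (quantization methods, ViT layer weights) resolves the device through get_current_device_id() instead of reading the tp rank or the CURRENT_DEVICE_ID variable directly. Below is a minimal sketch of that call order, not code from the series; it uses only the two helpers introduced above plus standard torch APIs, and init_worker is a made-up name for illustration.

    import torch
    from lightllm.utils.device_utils import set_current_device_id, get_current_device_id

    def init_worker(device_id: int):
        # bind this process to its GPU first, then record the id for later lookups
        torch.cuda.set_device(device_id)
        set_current_device_id(device_id)

    init_worker(0)
    # after init, tensors can be placed without threading device_id through every call site;
    # calling get_current_device_id() before init_worker() would raise the new RuntimeError
    t = torch.zeros(16, dtype=torch.bfloat16).cuda(get_current_device_id())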