diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py index 1e8634b903f..227293e497d 100644 --- a/python/llm/src/ipex_llm/transformers/models/internlm.py +++ b/python/llm/src/ipex_llm/transformers/models/internlm.py @@ -42,6 +42,7 @@ import torch import torch.utils.checkpoint from torch import nn +from ipex_llm.utils.common.log4Error import invalidInputError from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache @@ -307,10 +308,16 @@ def pre_process_attn_and_mlp(module: torch.nn.Module): def add_lora(x: torch.Tensor, result: torch.Tensor, im_mask: torch.Tensor = None, lora_scaling: float = 0, Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None): - if im_mask is not None and torch.sum(im_mask) > 0: - part_x = x[im_mask] - result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling) - return result + invalidInputError(x.dim() == 3 and result.dim() == 3, + "`x` and `result` should have 3 dims") + if len(im_mask) == 0 or x.size(1) == 1: + return result + else: + for start_idx, end_idx in im_mask: + result[:, start_idx:end_idx, :] += Plora_B( + Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling + ) + return result def internlm_xcomposser2_attention_forward( @@ -457,18 +464,32 @@ def internlm_xcomposser2_chat( **kwargs, ): # ipex-llm changes start: fix device and dtype conversion + # replace im_mask with start_idx and end_idx to improve performance if image is None: inputs = self.build_inputs(tokenizer, query, history, meta_instruction) - im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool() + im_mask = [] else: image = self.encode_img(image) inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image, history, meta_instruction) + mask = im_mask.cpu().flatten().tolist() + length = len(mask) + im_mask = [] + i = 0 + while i < length: + while i < length and not mask[i]: + i = i + 1 + start_idx = i + while i < length and mask[i]: + i = i + 1 + end_idx = i + if start_idx != end_idx: + im_mask.append((start_idx, end_idx)) + inputs = { k: v.to(device=self.device, dtype=self.dtype) for k, v in inputs.items() if torch.is_tensor(v) } - im_mask = im_mask.to(self.device) # ipex-llm changes end # also add end-of-assistant token in eos token id to avoid unnecessary generation