Commit

optimize internlm xcomposer2 performance (#11550)
MeouSker77 authored Jul 10, 2024
1 parent 3c16c9f commit 82f9514
Showing 1 changed file with 27 additions and 6 deletions.
python/llm/src/ipex_llm/transformers/models/internlm.py (33 changes: 27 additions & 6 deletions)
@@ -42,6 +42,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -307,10 +308,16 @@ def pre_process_attn_and_mlp(module: torch.nn.Module):
 def add_lora(x: torch.Tensor, result: torch.Tensor,
              im_mask: torch.Tensor = None, lora_scaling: float = 0,
              Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
-    if im_mask is not None and torch.sum(im_mask) > 0:
-        part_x = x[im_mask]
-        result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
-    return result
+    invalidInputError(x.dim() == 3 and result.dim() == 3,
+                      "`x` and `result` should have 3 dims")
+    if len(im_mask) == 0 or x.size(1) == 1:
+        return result
+    else:
+        for start_idx, end_idx in im_mask:
+            result[:, start_idx:end_idx, :] += Plora_B(
+                Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
+            )
+        return result
 
 
 def internlm_xcomposser2_attention_forward(
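
The rewritten add_lora above replaces boolean-mask indexing with plain slicing over contiguous (start_idx, end_idx) spans, so the image-token LoRA is applied as a few narrow matmuls instead of a gather/scatter on every call. Below is a minimal, self-contained sketch of the equivalence; it is not ipex-llm code, and the tensor sizes, the lora_a/lora_b layers and the example spans are made-up assumptions for illustration.

# Equivalence sketch (illustration only, not ipex-llm code): applying LoRA over
# contiguous (start, end) slices matches boolean-mask indexing.
# Sizes, layers and the example mask below are assumptions for the demo.
import torch

torch.manual_seed(0)
batch, seq_len, hidden, rank = 1, 8, 16, 4
x = torch.randn(batch, seq_len, hidden)
result = torch.randn(batch, seq_len, hidden)
lora_a = torch.nn.Linear(hidden, rank, bias=False)   # stands in for Plora_A
lora_b = torch.nn.Linear(rank, hidden, bias=False)   # stands in for Plora_B
lora_scaling = 2.0

# Boolean image-token mask as the model originally produced it.
im_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1]], dtype=torch.bool)

with torch.no_grad():
    # Old path: advanced indexing with the boolean mask.
    old = result.clone()
    old[im_mask] += lora_b(lora_a(x[im_mask]) * lora_scaling)

    # New path: the same positions expressed as half-open (start, end) ranges.
    ranges = [(2, 5), (6, 8)]
    new = result.clone()
    for start_idx, end_idx in ranges:
        new[:, start_idx:end_idx, :] += lora_b(
            lora_a(x[:, start_idx:end_idx, :]) * lora_scaling
        )

print(torch.allclose(old, new))  # expected: True

The x.size(1) == 1 early exit presumably skips this path during single-token decode steps, where the prompt-relative spans no longer line up with the current hidden states.
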
@@ -457,18 +464,32 @@ def internlm_xcomposser2_chat(
     **kwargs,
 ):
     # ipex-llm changes start: fix device and dtype conversion
+    # replace im_mask with start_idx and end_idx to improve performance
     if image is None:
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
-        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+        im_mask = []
     else:
         image = self.encode_img(image)
         inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                    history, meta_instruction)
+        mask = im_mask.cpu().flatten().tolist()
+        length = len(mask)
+        im_mask = []
+        i = 0
+        while i < length:
+            while i < length and not mask[i]:
+                i = i + 1
+            start_idx = i
+            while i < length and mask[i]:
+                i = i + 1
+            end_idx = i
+            if start_idx != end_idx:
+                im_mask.append((start_idx, end_idx))
+
     inputs = {
         k: v.to(device=self.device, dtype=self.dtype)
         for k, v in inputs.items() if torch.is_tensor(v)
     }
-    im_mask = im_mask.to(self.device)
     # ipex-llm changes end
 
     # also add end-of-assistant token in eos token id to avoid unnecessary generation
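
The while loop added in the chat path runs once per call and collapses the boolean im_mask returned by interleav_wrap_chat into half-open (start_idx, end_idx) spans of consecutive image tokens, which is exactly what the new add_lora consumes. As an illustration of what that loop computes (not the committed code), the same conversion can be written with itertools.groupby; the helper name and the sample mask below are assumptions for the example.

# Illustration only: an equivalent way to turn a flat boolean mask into the
# half-open (start, end) spans that the loop above produces.
from itertools import groupby

def mask_to_ranges(mask):
    """Collapse consecutive True values into [(start, end), ...] spans."""
    ranges, i = [], 0
    for value, group in groupby(mask):
        length = sum(1 for _ in group)
        if value:
            ranges.append((i, i + length))
        i += length
    return ranges

# Sample mask: image tokens sit at positions 2-4 and 6-7.
mask = [False, False, True, True, True, False, True, True]
print(mask_to_ranges(mask))  # [(2, 5), (6, 8)]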
