optimize internlm xcomposer2 performance #11550

Merged
python/llm/src/ipex_llm/transformers/models/internlm.py: 27 additions & 6 deletions
@@ -42,6 +42,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
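The newly imported invalidInputError is ipex-llm's input-validation helper. Judging from its use in add_lora below, invalidInputError(condition, message) is assumed to raise an error carrying message when condition is falsy, much like an assert. A small sketch of that assumed behavior, with a made-up tensor:

```python
import torch
from ipex_llm.utils.common.log4Error import invalidInputError

x = torch.randn(1, 4, 8)

# condition holds, so this is assumed to be a no-op
invalidInputError(x.dim() == 3, "`x` should have 3 dims")

# condition is False, so this is assumed to raise with the given message
invalidInputError(x.dim() == 2, "`x` should have 2 dims")
```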
@@ -307,10 +308,16 @@ def pre_process_attn_and_mlp(module: torch.nn.Module):
 def add_lora(x: torch.Tensor, result: torch.Tensor,
              im_mask: torch.Tensor = None, lora_scaling: float = 0,
              Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
-    if im_mask is not None and torch.sum(im_mask) > 0:
-        part_x = x[im_mask]
-        result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
-    return result
+    invalidInputError(x.dim() == 3 and result.dim() == 3,
+                      "`x` and `result` should have 3 dims")
+    if len(im_mask) == 0 or x.size(1) == 1:
+        return result
+    else:
+        for start_idx, end_idx in im_mask:
+            result[:, start_idx:end_idx, :] += Plora_B(
+                Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
+            )
+        return result
 
 
 def internlm_xcomposser2_attention_forward(
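The rewrite of add_lora is where the speedup comes from: the old path selected image-token rows with boolean-mask advanced indexing (x[im_mask]), which gathers and scatters individual rows, while the new path receives contiguous (start_idx, end_idx) ranges and applies the LoRA branch to plain slices. The extra x.size(1) == 1 guard presumably skips the branch during single-token decode steps, where the ranges only describe prompt positions. A minimal standalone sketch (not part of the PR; the shapes, mask, and ranges are made up) checking that the two paths agree:

```python
import torch

hidden, rank, seq = 8, 2, 10
Plora_A = torch.nn.Linear(hidden, rank, bias=False)
Plora_B = torch.nn.Linear(rank, hidden, bias=False)
lora_scaling = 0.5

x = torch.randn(1, seq, hidden)

# image tokens occupy positions 2..5 (half-open), expressed as a boolean mask
im_mask_bool = torch.zeros(1, seq, dtype=torch.bool)
im_mask_bool[:, 2:5] = True

# old path: gather/scatter via boolean-mask advanced indexing
old = x.clone()
old[im_mask_bool] += Plora_B(Plora_A(x[im_mask_bool]) * lora_scaling)

# new path: one contiguous slice per (start_idx, end_idx) range
new = x.clone()
for start_idx, end_idx in [(2, 5)]:
    new[:, start_idx:end_idx, :] += Plora_B(
        Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
    )

assert torch.allclose(old, new, atol=1e-6)
```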
@@ -457,18 +464,32 @@ def internlm_xcomposser2_chat(
     **kwargs,
 ):
     # ipex-llm changes start: fix device and dtype conversion
+    # replace im_mask with start_idx and end_idx to improve performance
     if image is None:
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
-        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+        im_mask = []
     else:
         image = self.encode_img(image)
         inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                    history, meta_instruction)
+        mask = im_mask.cpu().flatten().tolist()
+        length = len(mask)
+        im_mask = []
+        i = 0
+        while i < length:
+            while i < length and not mask[i]:
+                i = i + 1
+            start_idx = i
+            while i < length and mask[i]:
+                i = i + 1
+            end_idx = i
+            if start_idx != end_idx:
+                im_mask.append((start_idx, end_idx))
 
     inputs = {
         k: v.to(device=self.device, dtype=self.dtype)
         for k, v in inputs.items() if torch.is_tensor(v)
     }
-    im_mask = im_mask.to(self.device)
     # ipex-llm changes end
 
     # also add end-of-assistant token in eos token id to avoid unnecessary generation
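For reference, the range-extraction loop added above can be exercised on its own. This sketch (with a made-up mask, not taken from the PR) shows how a flat boolean im_mask collapses into half-open (start_idx, end_idx) pairs, one per contiguous run of image tokens:

```python
# flat boolean mask marking image tokens (made-up example input)
mask = [False, False, True, True, True, False, True, False]
length = len(mask)

im_mask = []
i = 0
while i < length:
    # skip the run of text tokens
    while i < length and not mask[i]:
        i = i + 1
    start_idx = i
    # consume the run of image tokens
    while i < length and mask[i]:
        i = i + 1
    end_idx = i
    if start_idx != end_idx:
        im_mask.append((start_idx, end_idx))

print(im_mask)  # [(2, 5), (6, 7)]
```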