Skip to content

Commit

Permalink
support Megrez-3B-Omni (#12582)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Dec 19, 2024
1 parent 4e7e988 commit 3eeb02f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
30 changes: 30 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,12 @@ def _optimize_pre(model, qtype=None):
model.llm.config.model_type = "minicpm"
_optimize_pre(model.llm, qtype=qtype)
model.llm.config.model_type = "minicpmv"
elif model.config.model_type == "megrezo":
from ipex_llm.transformers.models.minicpmv import merge_qkv
model.vision.apply(merge_qkv)
model.llm.config.model_type = "llama"
_optimize_pre(model.llm, qtype=qtype)
model.llm.config.model_type = "megrezo"
elif model.config.model_type == "chatglm":
if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
# chatglm2 and chatglm3
Expand Down Expand Up @@ -2202,5 +2208,29 @@ def safe_bmm_fwd(*args, **kwargs):
convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
model.chat = MethodType(minicpmv_chat, model)
elif model.config.model_type == "megrezo":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
minicpmv_generate = minicpmv_generate_wrapper(module.MegrezO.generate)
model.generate = MethodType(minicpmv_generate, model)

# vision
vpm_modeling_module_name = model.vision.vpm.__class__.__module__
vpm_module = importlib.import_module(vpm_modeling_module_name)
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
convert_forward(model.vision.vpm, vpm_module.SiglipAttention, siglip_attention_forward)

# resampler
from ipex_llm.transformers.models.minicpmv import _in_projection_packed
resampler_module_name = model.vision.resampler.__class__.__module__
resampler_module = importlib.import_module(resampler_module_name)
resampler_module._in_projection_packed = _in_projection_packed

# llm
model.llm.config.model_type = "llama"
model.llm.config.rope_scaling = {"rope_type": "default"}
_optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
model.llm.config.model_type = "megrezo"

return model
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
elif seq_length != kv_length and seq_length <= 32:
mask = None
else:
mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
dtype=dtype, device=device)
mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
mask[:, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
else:
if seq_length != kv_length and seq_length <= 32:
Expand Down

0 comments on commit 3eeb02f

Please sign in to comment.