From 3eeb02f1bee93a6f8a7163886214b173cfecfb20 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 19 Dec 2024 17:23:01 +0800
Subject: [PATCH] support Megrez-3B-Omni (#12582)

---
 .../llm/src/ipex_llm/transformers/convert.py  | 30 +++++++++++++++++++
 .../ipex_llm/transformers/models/common.py    |  4 +--
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index ada7be90ffd..66df48a5f72 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1055,6 +1055,12 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "minicpm"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    elif model.config.model_type == "megrezo":
+        from ipex_llm.transformers.models.minicpmv import merge_qkv
+        model.vision.apply(merge_qkv)
+        model.llm.config.model_type = "llama"
+        _optimize_pre(model.llm, qtype=qtype)
+        model.llm.config.model_type = "megrezo"
     elif model.config.model_type == "chatglm":
         if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
             # chatglm2 and chatglm3
@@ -2202,5 +2208,29 @@ def safe_bmm_fwd(*args, **kwargs):
         convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
         minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
         model.chat = MethodType(minicpmv_chat, model)
+    elif model.config.model_type == "megrezo":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
+        minicpmv_generate = minicpmv_generate_wrapper(module.MegrezO.generate)
+        model.generate = MethodType(minicpmv_generate, model)
+
+        # vision
+        vpm_modeling_module_name = model.vision.vpm.__class__.__module__
+        vpm_module = importlib.import_module(vpm_modeling_module_name)
+        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+        convert_forward(model.vision.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
+
+        # resampler
+        from ipex_llm.transformers.models.minicpmv import _in_projection_packed
+        resampler_module_name = model.vision.resampler.__class__.__module__
+        resampler_module = importlib.import_module(resampler_module_name)
+        resampler_module._in_projection_packed = _in_projection_packed
+
+        # llm
+        model.llm.config.model_type = "llama"
+        model.llm.config.rope_scaling = {"rope_type": "default"}
+        _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+        model.llm.config.model_type = "megrezo"
 
     return model
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 0c140c5c68f..4303dbd3a18 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -198,8 +198,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
         elif seq_length != kv_length and seq_length <= 32:
             mask = None
         else:
-            mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
-                               dtype=dtype, device=device)
+            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
+            mask[:, :, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
             mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
     else:
         if seq_length != kv_length and seq_length <= 32:
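
Notes on the approach. Both new `megrezo` branches reuse the existing llama code paths by temporarily relabeling the wrapped language model: `model.llm.config.model_type` is set to "llama", `_optimize_pre`/`_optimize_post` run their stock llama rewrites, and the "megrezo" tag is restored afterwards (setting `rope_scaling = {"rope_type": "default"}` presumably gives the llama rope path an explicit config to read). A minimal sketch of that swap, using stand-in objects rather than ipex-llm internals:

```python
# Sketch of the model_type swap used in both megrezo branches: relabel the
# wrapped LLM as "llama", run the existing llama optimizations, restore the tag.
from types import SimpleNamespace

def _optimize_pre(model, qtype=None):
    # stands in for the real function, which dispatches on config.model_type
    if model.config.model_type == "llama":
        model.optimized = True

llm = SimpleNamespace(config=SimpleNamespace(model_type="megrezo"),
                      optimized=False)

llm.config.model_type = "llama"    # borrow the llama path
_optimize_pre(llm, qtype=None)
llm.config.model_type = "megrezo"  # put the real tag back

assert llm.optimized and llm.config.model_type == "megrezo"
```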
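
The vision tower is optimized through `convert_forward`, which walks the module tree and rebinds an optimized `forward` onto every instance of a target class (here `SiglipAttention` gets `siglip_attention_forward`). A self-contained sketch of that rebinding pattern; `ToyAttention` and `fast_forward` are made-up stand-ins, and the real helper lives in ipex_llm/transformers/convert.py:

```python
# Toy version of the convert_forward() pattern: rebind an optimized forward
# onto every instance of a target class in the module tree.
from types import MethodType
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, x):
        return x

def fast_forward(self, x):
    # stands in for e.g. siglip_attention_forward
    return x * 2

def convert_forward(model, target_cls, new_forward):
    for module in model.modules():
        if isinstance(module, target_cls):
            # bind as an instance method, shadowing the class's forward
            module.forward = MethodType(new_forward, module)

model = nn.Sequential(ToyAttention(), nn.Identity(), ToyAttention())
convert_forward(model, ToyAttention, fast_forward)
print(model(torch.ones(2)))  # tensor([4., 4.]) -- both attentions now double x
```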
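
The resampler is handled differently: rather than rebinding a method, the patch overwrites the free function `_in_projection_packed` in the resampler's source module, so every call site inside that module resolves to the replacement at call time. A toy illustration of why assigning into a module's namespace is sufficient (the module built here is synthetic):

```python
# Demonstrates module-level monkey-patching: a function defined inside a
# module looks up free names in that module's globals at call time.
import types

toy = types.ModuleType("toy_resampler")
exec(
    "def _in_projection_packed(x):\n"
    "    return x\n"
    "def attention(x):\n"
    "    return _in_projection_packed(x)\n",
    toy.__dict__,
)

print(toy.attention(3))                    # 3 -- original free function

# patch the name at module scope, as the commit does for the resampler
toy._in_projection_packed = lambda x: x * 10
print(toy.attention(3))                    # 30 -- caller sees the replacement
```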
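
The common.py hunk also repairs an invalid call: `torch.zeros` takes no fill value, so passing `torch.finfo(dtype).min` positionally raised a `TypeError`. The replacement builds an all-zero mask and writes the dtype minimum into the padded key positions before expanding. A small repro of the fixed construction, with illustrative sizes (kv_length 5 padded to 8) rather than ipex-llm's own:

```python
import torch

dtype, device = torch.float16, "cpu"
bsz, n_heads, seq_length = 2, 4, 1
kv_length, padding_kv_length = 5, 8  # illustrative sizes

# zero scores for real positions, dtype minimum for the padded tail so
# softmax gives the padding ~zero weight
mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
mask[:, :, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])

print(mask[0, 0, 0])
# tensor([0., 0., 0., 0., 0., -65504., -65504., -65504.], dtype=torch.float16)
```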