diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 828ea5563ea..4f1b0d3d63f 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1504,6 +1504,17 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.GlmAttention, glm_attention_forward)
         glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
         convert_forward(model, module.GlmModel, glm_model_forward)
+
+        if hasattr(model.model, "vision"):
+            # glm-edge-v series
+            vision_module_name = model.model.vision.__class__.__module__
+            vision_module = importlib.import_module(vision_module_name)
+            from transformers.models.siglip.modeling_siglip import SiglipAttention
+            from ipex_llm.transformers.models.chatglm4v import vision_model_forward
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            convert_forward(model, vision_module.VisionModel, vision_model_forward)
+            convert_forward(model, SiglipAttention, siglip_attention_forward)
+
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
             modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py
index c82ebc32c61..485a449d294 100644
--- a/python/llm/src/ipex_llm/transformers/models/glm.py
+++ b/python/llm/src/ipex_llm/transformers/models/glm.py
@@ -37,7 +37,6 @@
 from typing import Optional, Tuple
 
 from transformers.cache_utils import Cache
-from transformers.models.glm.modeling_glm import GlmAttention, GlmMLP
 from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base
@@ -46,11 +45,12 @@
 
 
 def merge_qkv(module: torch.nn.Module):
-    merge_qkv_base(module, GlmAttention)
+    merge_qkv_base(module, "GlmAttention")
+    merge_qkv_base(module, "SiglipAttention")
 
 
 def split_mlp(module: torch.nn.Module):
-    if isinstance(module, GlmMLP):
+    if module.__class__.__name__ == "GlmMLP":
         gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0)
 
         gate_proj = torch.nn.Linear(0, 0, bias=False)
@@ -157,6 +157,7 @@ def glm_model_forward_wrapper(origin_forward):
     def glm_model_forward(
         self,
         input_ids: torch.LongTensor = None,
+        images: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
@@ -166,7 +167,7 @@ def glm_model_forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs,
+        **kwargs,
     ):
         # ipex-llm changes start
         # IPEX-LLM OPT: kv cache and quantize kv cache
@@ -187,6 +188,7 @@ def glm_model_forward(
         return origin_forward(
             self=self,
             input_ids=input_ids,
+            images=images,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
@@ -196,7 +198,7 @@ def glm_model_forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
-            **flash_attn_kwargs,
+            **kwargs,
         )
 
     return glm_model_forward
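
Note (not part of the patch): the changes above rely on ipex-llm's convert_forward helper to swap optimized forward functions onto the GLM and SigLIP modules at load time. The snippet below is only a minimal sketch of that class-name-based forward-patching pattern, assuming a helper that walks the module tree and rebinds forward on matching instances; convert_forward_sketch, new_attention_forward, and DummyAttention are hypothetical names used for illustration and are not the library's implementation.

import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls_name: str, new_forward):
    # Rebind `forward` on every submodule whose class name matches the target.
    # This mirrors the general monkey-patching pattern used by the diff's
    # convert_forward(model, SomeClass, some_forward) calls; it is a sketch,
    # not the actual ipex-llm code.
    for module in model.modules():
        if module.__class__.__name__ == target_cls_name:
            # Bind the replacement function as a method of this instance.
            module.forward = new_forward.__get__(module, module.__class__)

# Hypothetical replacement forward, standing in for e.g. siglip_attention_forward.
def new_attention_forward(self, hidden_states, *args, **kwargs):
    # ...optimized attention would go here...
    return hidden_states

class DummyAttention(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states

model = torch.nn.Sequential(DummyAttention())
convert_forward_sketch(model, "DummyAttention", new_attention_forward)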