From 1a9c61bd3e7499c17ab1a61b0bb337f8b1b87a40 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 12 Dec 2024 17:04:08 +0800
Subject: [PATCH] add basic glm-edge-v support

---
 python/llm/src/ipex_llm/transformers/convert.py    | 11 +++++++++++
 python/llm/src/ipex_llm/transformers/models/glm.py |  6 +++---
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 828ea5563ea..4f1b0d3d63f 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1504,6 +1504,17 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.GlmAttention, glm_attention_forward)
         glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
         convert_forward(model, module.GlmModel, glm_model_forward)
+
+        if hasattr(model.model, "vision"):
+            # glm-edge-v series
+            vision_module_name = model.model.vision.__class__.__module__
+            vision_module = importlib.import_module(vision_module_name)
+            from transformers.models.siglip.modeling_siglip import SiglipAttention
+            from ipex_llm.transformers.models.chatglm4v import vision_model_forward
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            convert_forward(model, vision_module.VisionModel, vision_model_forward)
+            convert_forward(model, SiglipAttention, siglip_attention_forward)
+
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
             modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py
index c82ebc32c61..fdd19705d6a 100644
--- a/python/llm/src/ipex_llm/transformers/models/glm.py
+++ b/python/llm/src/ipex_llm/transformers/models/glm.py
@@ -37,7 +37,6 @@
 from typing import Optional, Tuple
 
 from transformers.cache_utils import Cache
-from transformers.models.glm.modeling_glm import GlmAttention, GlmMLP
 from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base
@@ -46,11 +45,12 @@
 
 
 def merge_qkv(module: torch.nn.Module):
-    merge_qkv_base(module, GlmAttention)
+    merge_qkv_base(module, "GlmAttention")
+    merge_qkv_base(module, "SiglipAttention")
 
 
 def split_mlp(module: torch.nn.Module):
-    if isinstance(module, GlmMLP):
+    if module.__class__.__name__ == "GlmMLP":
         gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0)
 
         gate_proj = torch.nn.Linear(0, 0, bias=False)
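
Reviewer note (not part of the patch): a minimal way to exercise the new glm-edge-v branch is sketched below. The checkpoint id is a placeholder, and the sketch assumes glm-edge-v loads through ipex-llm's AutoModelForCausalLM wrapper like the other GLM variants, so that _optimize_post runs and reaches the new hasattr(model.model, "vision") check. The switch in merge_qkv/split_mlp from isinstance to class-name matching presumably keeps those helpers working when the attention/MLP classes come from the checkpoint's remote code (or from SigLIP) rather than from transformers.models.glm itself.

    # Hypothetical smoke test for the new glm-edge-v path (not part of this patch).
    # The model id below is a placeholder; substitute any glm-edge-v checkpoint.
    from ipex_llm.transformers import AutoModelForCausalLM

    model_path = "THUDM/glm-edge-v-2b"  # assumption: a glm-edge-v style checkpoint

    # Loading with a low-bit dtype runs ipex-llm's conversion pipeline
    # (_optimize_post), which with this patch also rewrites the vision tower's
    # forward and SiglipAttention for glm-edge-v models.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        load_in_4bit=True,
        trust_remote_code=True,
    )

    # The new branch only fires when the language model carries a vision tower,
    # mirroring the hasattr check added in convert.py.
    assert hasattr(model.model, "vision"), "expected a glm-edge-v vision tower"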