Skip to content

Commit

Permalink
add basic glm-edge-v support
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Dec 12, 2024
1 parent 3e0823d commit 1a9c61b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
11 changes: 11 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,17 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model, module.GlmAttention, glm_attention_forward)
glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
convert_forward(model, module.GlmModel, glm_model_forward)

if hasattr(model.model, "vision"):
# glm-edge-v series
vision_module_name = model.model.vision.__class__.__module__
vision_module = importlib.import_module(vision_module_name)
from transformers.models.siglip.modeling_siglip import SiglipAttention
from ipex_llm.transformers.models.chatglm4v import vision_model_forward
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
convert_forward(model, vision_module.VisionModel, vision_model_forward)
convert_forward(model, SiglipAttention, siglip_attention_forward)

elif "mpt" in model.config.model_type:
if model.config.architectures is not None:
modeling_module_name = model.__class__.__module__
Expand Down
6 changes: 3 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

from typing import Optional, Tuple
from transformers.cache_utils import Cache
from transformers.models.glm.modeling_glm import GlmAttention, GlmMLP
from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from ipex_llm.transformers.models.common import merge_qkv_base
Expand All @@ -46,11 +45,12 @@


def merge_qkv(module: torch.nn.Module):
merge_qkv_base(module, GlmAttention)
merge_qkv_base(module, "GlmAttention")
merge_qkv_base(module, "SiglipAttention")


def split_mlp(module: torch.nn.Module):
if isinstance(module, GlmMLP):
if module.__class__.__name__ == "GlmMLP":
gate_weight, up_weight = module.gate_up_proj.weight.data.chunk(2, dim=0)

gate_proj = torch.nn.Linear(0, 0, bias=False)
Expand Down

0 comments on commit 1a9c61b

Please sign in to comment.