diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 081178d0ae8..f144a461902 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1325,7 +1325,6 @@ def _optimize_post(model):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
-        from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
         from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
         from ipex_llm.transformers.models.chatglm2 import mlp_forward
@@ -1338,9 +1337,7 @@ def _optimize_post(model):
         convert_forward(model,
                         module.ChatGLMModel,
                         chatglm2_model_forward)
-        convert_forward(model,
-                        module.RMSNorm,
-                        chatglm_rms_norm_forward)
+        convert_forward(model, module.RMSNorm, rms_norm_forward)
         convert_forward(model, module.MLP, mlp_forward)
         # for codegeex-nano
         if hasattr(model.config, "rope_ratio"):
@@ -1358,8 +1355,7 @@ def _optimize_post(model):
         # glm4 family
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
-        convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+        convert_forward(model, module.RMSNorm, rms_norm_forward)
 
         if hasattr(model.transformer, "vision"):
             # glm4 vision family
@@ -1448,8 +1444,8 @@ def _optimize_post(model):
     elif model.config.model_type == "baichuan":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
-        convert_forward(model, module.MLP, baichuan_mlp_forward)
+        convert_forward(model, module.RMSNorm, rms_norm_forward)
+        convert_forward(model, module.MLP, mlp_silu_forward)
 
         if model.config.hidden_size in [4096, 2048]:
             # baichuan-7B and baichuan2-7B
@@ -1458,7 +1454,6 @@ def _optimize_post(model):
             for i in range(len(model.model.layers)):
                 setattr(model.model.layers[i].self_attn, "layer_idx", i)
             convert_forward(model, module.Attention, baichuan_attention_forward_7b)
-            convert_forward(model, module.RMSNorm, rms_norm_forward)
             if model.config.vocab_size == 125696:
                 # baichuan2-7B
                 convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
@@ -1468,9 +1463,7 @@ def _optimize_post(model):
         elif model.config.hidden_size == 5120:
             # baichuan-13B and baichuan2-13B
             from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
-            from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
             convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
-            convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)
 
             if model.config.vocab_size == 125696:
                 # baichaun2-13B
@@ -1565,7 +1558,6 @@ def _optimize_post(model):
         from ipex_llm.transformers.models.qwen import qwen_attention_forward
         from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
         from ipex_llm.transformers.models.qwen import qwen_mlp_forward
-        from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         from ipex_llm.transformers.models.qwen import qwen_model_forward
         if model.config.max_position_embeddings == 8192 \
                 and model.config.hidden_size == 4096:
@@ -1580,7 +1572,7 @@ def _optimize_post(model):
             )
         convert_forward(model,
                         module.RMSNorm,
-                        chatglm_rms_norm_forward)
+                        rms_norm_forward)
         convert_forward(model, module.QWenMLP, qwen_mlp_forward)
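With this change, ChatGLM/GLM4, Baichuan, and Qwen all route their RMSNorm layers through the shared `rms_norm_forward` in `ipex_llm.transformers.models.common`, and Baichuan's MLP through the shared `mlp_silu_forward`, instead of per-model copies. `mlp_silu_forward` itself is not part of this diff; a minimal sketch, assuming it matches the unfused fallback path of the removed `baichuan_mlp_forward` (shown below), of the SiLU-gated MLP being shared:

```python
import torch
import torch.nn.functional as F


def silu_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Unfused SiLU-gated (SwiGLU-style) MLP: the SiLU-activated gate
    # projection modulates the up projection, then down_proj maps the
    # result back to the hidden size. The actual shared helper presumably
    # also dispatches to the fused xe_linear kernel on XPU.
    return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```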
diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index a78e5f8e131..ad0c780adda 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -47,38 +47,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
         module.register_buffer("inv_freq", inv_freq, persistent=False)
 
 
-def baichuan_13b_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-    return self.weight * hidden_states.to(input_dtype)
-
-
-def baichuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    qtype = getattr(self.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        return self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            SILU, qtype
-        ))
-    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 def baichuan_model_7b_forward(
     self,
     input_ids: torch.LongTensor = None,
diff --git a/python/llm/src/ipex_llm/transformers/models/bert.py b/python/llm/src/ipex_llm/transformers/models/bert.py
index 4c83ba6c7a8..810d89b4e01 100644
--- a/python/llm/src/ipex_llm/transformers/models/bert.py
+++ b/python/llm/src/ipex_llm/transformers/models/bert.py
@@ -36,24 +36,13 @@
 import torch
 from typing import Optional, Tuple
 from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+from ipex_llm.transformers.models.common import merge_linear
 from ipex_llm.utils.common import invalidInputError
 
 
 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, BertSelfAttention):
-        q_w = module.query.weight.data
-        k_w = module.key.weight.data
-        v_w = module.value.weight.data
-        q_b = module.query.bias.data
-        k_b = module.key.bias.data
-        v_b = module.value.bias.data
-        new_w = torch.cat([q_w, k_w, v_w], dim=0)
-        new_b = torch.cat([q_b, k_b, v_b], dim=-1)
-        qkv = torch.nn.Linear(0, 0, bias=True)
-        qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
-        qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
-        qkv.in_features = module.query.in_features
-        qkv.out_features = module.query.out_features * 3
+        qkv = merge_linear([module.query, module.key, module.value])
         module.qkv = qkv
         del module.query
         del module.key
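`merge_linear` is imported from `ipex_llm.transformers.models.common` and its body is not shown in this diff. Judging by the inline code it replaces, it concatenates several projections along the output dimension into a single `nn.Linear`, so Q/K/V run as one matmul. A minimal sketch under that assumption, mirroring the removed inline logic:

```python
from typing import List

import torch


def merge_linear_sketch(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
    # Weights are [out_features, in_features], so stacking the output
    # dimensions means concatenating along dim 0; biases are [out_features],
    # so they concatenate along the last dim.
    new_weight = torch.cat([linear.weight.data for linear in linears], dim=0)
    new_bias = torch.cat([linear.bias.data for linear in linears], dim=-1)
    merged = torch.nn.Linear(0, 0, bias=True)
    merged.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    merged.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    merged.in_features = linears[0].in_features
    merged.out_features = new_weight.size(0)
    return merged
```

The caller then splits the merged output back into Q, K, and V after a single forward pass, which is why `module.query`/`key`/`value` can be deleted above.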
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index beb3653a6b5..236d02518fb 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -33,34 +33,6 @@
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
-    go from (batch, num_key_value_heads, seqlen, head_dim) to
-    (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                           n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def chatglm_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-    return self.weight * hidden_states.to(input_dtype)
-
-
 def chatglm2_model_forward(
     self,
     input_ids,
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 02be8a51eaf..29520b44ec9 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
     weight = self.weight
     if hasattr(self, "variance_epsilon"):
         eps = self.variance_epsilon
-    else:
+    elif hasattr(self, "epsilon"):
         eps = self.epsilon
+    else:
+        eps = self.eps
 
     if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
         import xe_addons
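The `common.py` hunk is what makes the consolidation above possible: the upstream modeling files disagree on the name of the RMSNorm epsilon attribute. LLaMA-style layers expose `variance_epsilon`, Baichuan-13B's `RMSNorm` uses `epsilon` (see the removed `baichuan_13b_rms_norm_forward`), and ChatGLM's uses `eps` (see the removed `chatglm_rms_norm_forward`), so the shared forward resolves `eps` through that fallback chain. Once resolved, the computation is the standard RMSNorm; a sketch of the non-XPU fallback path, mirroring the removed per-model implementations (the `xe_addons.rms_norm` fast path is omitted here):

```python
import torch


def rms_norm_fallback(weight: torch.Tensor, hidden_states: torch.Tensor,
                      eps: float) -> torch.Tensor:
    # Compute the norm in float32 for numerical stability, then cast back
    # to the input dtype before applying the learned scale.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)
```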