refactor to simplify following upgrade
MeouSker77 committed Jan 9, 2025
1 parent 5c24276 commit 38e5ff5
Showing 5 changed files with 10 additions and 87 deletions.
18 changes: 5 additions & 13 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1325,7 +1325,6 @@ def _optimize_post(model):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
from ipex_llm.transformers.models.chatglm2 import mlp_forward
@@ -1338,9 +1337,7 @@ def _optimize_post(model):
convert_forward(model,
module.ChatGLMModel,
chatglm2_model_forward)
- convert_forward(model,
-                 module.RMSNorm,
-                 chatglm_rms_norm_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)
convert_forward(model, module.MLP, mlp_forward)
# for codegeex-nano
if hasattr(model.config, "rope_ratio"):
@@ -1358,8 +1355,7 @@ def _optimize_post(model):
# glm4 family
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
- convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)

if hasattr(model.transformer, "vision"):
# glm4 vision family
@@ -1448,8 +1444,8 @@ def _optimize_post(model):
elif model.config.model_type == "baichuan":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
- from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
- convert_forward(model, module.MLP, baichuan_mlp_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)
+ convert_forward(model, module.MLP, mlp_silu_forward)

if model.config.hidden_size in [4096, 2048]:
# baichuan-7B and baichuan2-7B
@@ -1458,7 +1454,6 @@ def _optimize_post(model):
for i in range(len(model.model.layers)):
setattr(model.model.layers[i].self_attn, "layer_idx", i)
convert_forward(model, module.Attention, baichuan_attention_forward_7b)
- convert_forward(model, module.RMSNorm, rms_norm_forward)
if model.config.vocab_size == 125696:
# baichuan2-7B
convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
@@ -1468,9 +1463,7 @@ def _optimize_post(model):
elif model.config.hidden_size == 5120:
# baichuan-13B and baichuan2-13B
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
- from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
- convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)

if model.config.vocab_size == 125696:
# baichaun2-13B
@@ -1565,7 +1558,6 @@ def _optimize_post(model):
from ipex_llm.transformers.models.qwen import qwen_attention_forward
from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
from ipex_llm.transformers.models.qwen import qwen_mlp_forward
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
from ipex_llm.transformers.models.qwen import qwen_model_forward
if model.config.max_position_embeddings == 8192 \
and model.config.hidden_size == 4096:
@@ -1580,7 +1572,7 @@ def _optimize_post(model):
)
convert_forward(model,
module.RMSNorm,
-                 chatglm_rms_norm_forward)
+                 rms_norm_forward)
convert_forward(model,
module.QWenMLP,
qwen_mlp_forward)
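
Note: the calls above swap each module class's forward for the shared implementations in models/common.py (rms_norm_forward, and mlp_silu_forward for Baichuan), replacing the per-model versions deleted below. As a rough, hedged sketch of how a convert_forward-style helper could work, assuming it simply rebinds forward on every matching submodule (not the actual ipex_llm code):

    # Hedged sketch only: assumes the helper rebinds `forward` on every submodule
    # that is an instance of the target class; the real ipex_llm helper may differ.
    import types

    import torch


    def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
        # Walk all submodules and replace the bound forward method on matches.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)

    # e.g. convert_forward_sketch(model, module.RMSNorm, rms_norm_forward)
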
32 changes: 0 additions & 32 deletions python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -47,38 +47,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
module.register_buffer("inv_freq", inv_freq, persistent=False)


- def baichuan_13b_rms_norm_forward(self, hidden_states):
-     if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
-         import xe_addons
-         x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-         output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
-         return output.reshape(hidden_states.shape)
-
-     input_dtype = hidden_states.dtype
-     hidden_states = hidden_states.to(torch.float32)
-     variance = hidden_states.pow(2).mean(-1, keepdim=True)
-     hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-     return self.weight * hidden_states.to(input_dtype)
-
-
- def baichuan_mlp_forward(
-     self,
-     x: torch.Tensor,
- ) -> torch.Tensor:
-     x_2d = x.view(-1, x.shape[-1])
-     qtype = getattr(self.gate_proj, "qtype", None)
-     if mlp_fusion_check(x_2d, qtype, self.training):
-         import xe_linear
-         if not x_2d.is_contiguous():
-             x_2d = x_2d.contiguous()
-         return self.down_proj(xe_linear.mlp_forward_xpu(
-             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
-             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-             SILU, qtype
-         ))
-     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
def baichuan_model_7b_forward(
self,
input_ids: torch.LongTensor = None,
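
Note: baichuan_mlp_forward is deleted here because convert.py now registers the shared mlp_silu_forward instead. A minimal sketch of the SiLU-gated MLP computation that shared path is assumed to reproduce (only the eager fallback of the deleted function; the fused xe_linear XPU kernel is omitted):

    import torch
    import torch.nn.functional as F


    def mlp_silu_forward_sketch(self, x: torch.Tensor) -> torch.Tensor:
        # down_proj(silu(gate_proj(x)) * up_proj(x)) -- the standard gated-MLP form.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
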
15 changes: 2 additions & 13 deletions python/llm/src/ipex_llm/transformers/models/bert.py
@@ -36,24 +36,13 @@
import torch
from typing import Optional, Tuple
from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+ from ipex_llm.transformers.models.common import merge_linear
from ipex_llm.utils.common import invalidInputError


def merge_qkv(module: torch.nn.Module):
if isinstance(module, BertSelfAttention):
-         q_w = module.query.weight.data
-         k_w = module.key.weight.data
-         v_w = module.value.weight.data
-         q_b = module.query.bias.data
-         k_b = module.key.bias.data
-         v_b = module.value.bias.data
-         new_w = torch.cat([q_w, k_w, v_w], dim=0)
-         new_b = torch.cat([q_b, k_b, v_b], dim=-1)
-         qkv = torch.nn.Linear(0, 0, bias=True)
-         qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
-         qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
-         qkv.in_features = module.query.in_features
-         qkv.out_features = module.query.out_features * 3
+         qkv = merge_linear([module.query, module.key, module.value])
module.qkv = qkv
del module.query
del module.key
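
Note: the hand-written weight concatenation removed above is replaced by merge_linear from models/common.py. A sketch of what that helper is assumed to do, based only on the code it replaces (the real helper may handle dtype, device, and bias cases differently):

    from typing import List

    import torch


    def merge_linear_sketch(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
        # Stack the weights (and biases) of Linears that share an input along the
        # output dimension, yielding one fused projection.
        new_weight = torch.cat([lin.weight.data for lin in linears], dim=0)
        has_bias = linears[0].bias is not None
        merged = torch.nn.Linear(linears[0].in_features,
                                 sum(lin.out_features for lin in linears),
                                 bias=has_bias)
        merged.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        if has_bias:
            new_bias = torch.cat([lin.bias.data for lin in linears], dim=0)
            merged.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        return merged

For BertSelfAttention this yields a single qkv projection whose out_features is three times that of the query projection, matching the deleted code.
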
28 changes: 0 additions & 28 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -33,34 +33,6 @@
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """
-     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
-     go from (batch, num_key_value_heads, seqlen, head_dim) to
-     (batch, num_attention_heads, seqlen, head_dim)
-     """
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                            n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- def chatglm_rms_norm_forward(self, hidden_states):
-     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-         import xe_addons
-         x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-         output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
-         return output.reshape(hidden_states.shape)
-
-     input_dtype = hidden_states.dtype
-     hidden_states = hidden_states.to(torch.float32)
-     variance = hidden_states.pow(2).mean(-1, keepdim=True)
-     hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-     return self.weight * hidden_states.to(input_dtype)
-
-
def chatglm2_model_forward(
self,
input_ids,
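
Note: the deleted repeat_kv docstring claims equivalence with torch.repeat_interleave along the head dimension; a quick self-contained check of that claim (shapes chosen arbitrarily for illustration):

    import torch

    batch, kv_heads, seqlen, head_dim, n_rep = 2, 4, 8, 16, 3
    kv = torch.randn(batch, kv_heads, seqlen, head_dim)

    # Expand-then-reshape over KV heads, as in the removed helper ...
    expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seqlen, head_dim)
    expanded = expanded.reshape(batch, kv_heads * n_rep, seqlen, head_dim)

    # ... matches torch.repeat_interleave on dim=1.
    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
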
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/common.py
@@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
weight = self.weight
if hasattr(self, "variance_epsilon"):
eps = self.variance_epsilon
-     else:
+     elif hasattr(self, "epsilon"):
          eps = self.epsilon
+     else:
+         eps = self.eps

if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
import xe_addons
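
Note: with this extra branch the shared rms_norm_forward can resolve epsilon on Llama-style modules (variance_epsilon), Baichuan-13B-style modules (epsilon), and ChatGLM/Baichuan-7B-style modules (eps), which is what lets the per-model RMSNorm forwards above be deleted. A sketch of the non-XPU fallback under that assumption (the XPU fast path via xe_addons is omitted):

    import torch


    def rms_norm_forward_sketch(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Resolve epsilon across the attribute names used by the supported families.
        if hasattr(self, "variance_epsilon"):    # HF Llama-style RMSNorm
            eps = self.variance_epsilon
        elif hasattr(self, "epsilon"):           # Baichuan-13B-style RMSNorm
            eps = self.epsilon
        else:                                    # ChatGLM-style RMSNorm
            eps = self.eps

        # Reference RMSNorm math in float32, cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return self.weight * hidden_states.to(input_dtype)
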
