use new rotary two in chatglm4
qiuxin2012 committed Jun 13, 2024
1 parent f1410d6 commit bd47da2
Showing 1 changed file with 32 additions and 45 deletions.
77 changes: 32 additions & 45 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -47,10 +47,6 @@ def chatglm4_model_forward(
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
# if use_cache and use_quantize_kv_cache(
# self.encoder.layers[0].self_attention.query_key_value, input_ids):
# if not isinstance(past_key_values, DynamicFp8Cache):
# past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return chatglm4_model_forward_internal(
self=self,
input_ids=input_ids,
@@ -108,25 +104,17 @@ def chatglm4_model_forward_internal(
dtype=torch.int64, device=inputs_embeds.device)
position_ids = position_ids.repeat(batch_size, 1)

use_fuse_rope = input_ids.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not self.training

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
if use_fuse_rope:
# Repeat cos and sin here, so they are computed only once for each token.
# Chatglm2's rotary embedding is similar to gptj's: it rotates every two elements.
# If this were done in the attention forward, it would be recomputed too many times.
cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
cos = cos.squeeze(-1)
sin = sin.squeeze(-1)
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
rotary_pos_emb = (cos, sin)
if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
rot_dim = self.rotary_pos_emb.dim
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
# Generate inv_freq in float to avoid overflow, as base is too large.
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
dtype=torch.float,
device=inputs_embeds.device) / rot_dim))
self.rotary_pos_emb.register_buffer("inv_freq",
inv_freq.to(inputs_embeds.dtype),
persistent=False)
self.rotary_pos_emb.cached = True
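
A minimal sketch of why the added block generates inv_freq in float before casting it to the embedding dtype: with a large rope_ratio the effective base exceeds float16's maximum (~65504), so computing the power in half precision collapses the low-frequency terms to zero. The rot_dim and rope_ratio values below are illustrative assumptions, not taken from the patch or any model config.

import torch

# Illustrative values only; rot_dim and rope_ratio here are assumptions.
rot_dim, rope_ratio = 64, 500
base = 10000 * rope_ratio  # 5e6, well beyond float16's max of ~65504

exponents = torch.arange(0, rot_dim, 2, dtype=torch.float) / rot_dim

# In float16 the base itself overflows to inf, so 1 / (base ** exponent)
# collapses to 0 for every non-zero exponent.
bad = 1.0 / (torch.tensor(base, dtype=torch.float16) ** exponents.to(torch.float16))

# Computing in float32 first and casting the (<= 1) result afterwards keeps it finite.
good = (1.0 / (base ** exponents)).to(torch.float16)

print(bad[-1].item(), good[-1].item())  # 0.0 vs. a small non-zero value

This is the reason the patch builds inv_freq with dtype=torch.float and only converts it to inputs_embeds.dtype when registering the buffer.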

# `full_attention_mask` is not None only when
# `past_key_values` is not None and `seq_length` > 1
@@ -148,7 +136,7 @@ def chatglm4_model_forward_internal(

hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds, causal_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb=(self.rotary_pos_emb.inv_freq, position_ids),
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
)
# ipex-llm changes end
@@ -209,34 +197,33 @@ def chatglm4_attention_forward(
qkv = self.query_key_value(hidden_states)
# [bs, q_len, np * 3 * hn] -> [bsz, n_head, seq_len, head_dim]
qkv = qkv.view(bsz, q_len, n_head + 2 * n_kv_head, head_dim)
qkv = qkv.transpose(1, 2)

query_states, key_states, value_states = qkv.split([n_head,
n_kv_head,
n_kv_head], dim=2)
n_kv_head], dim=1)

kv_seq_len = key_states.shape[1]
kv_seq_len = key_states.shape[2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[2]

if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
# use_fuse_rope, see chatglm4_model_forward
cos, sin = rotary_pos_emb
rot_dim = cos.shape[-1]
query_layer_cur = query_states[..., :rot_dim]
key_layer_cur = key_states[..., :rot_dim]
# ipex_llm's apply_rotary_embedding can modify the original storage,
# so query_layer will get the result directly.
torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
elif rotary_pos_emb is not None:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
# IPEX-LLM OPT: fuse rope
inv_freq, position_ids = rotary_pos_emb
rot_dim = inv_freq.size(-1) * 2
if should_use_fuse_rope(hidden_states, rotary_pos_emb[1], self.training):
import xe_addons
xe_addons.rotary_two_inplaced(inv_freq, position_ids,
query_states[..., :rot_dim], key_states[..., :rot_dim])
else:
idx_theta = torch.outer(position_ids[0].float(),
inv_freq.float()).to(hidden_states.dtype)
idx_theta = idx_theta.unsqueeze(0).unsqueeze(0)
cos = torch.cos(idx_theta).repeat_interleave(2, -1)
sin = torch.sin(idx_theta).repeat_interleave(2, -1)
q_rot, k_rot = apply_rotary_pos_emb(query_states[..., :rot_dim], key_states[..., :rot_dim],
cos, sin, position_ids, "chatglm")
query_states[..., :rot_dim] = q_rot[...]
key_states[..., :rot_dim] = k_rot[...]
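
For reference, a minimal eager sketch of the rotate-every-two (GPT-J-style, interleaved) rotation that the fused xe_addons.rotary_two_inplaced kernel and the fallback branch above both apply to the first rot_dim channels. The function name and tensor shapes below are assumptions for illustration, not the kernel's actual signature.

import torch

def rotate_every_two_ref(x, inv_freq, position_ids):
    # x: [bsz, n_head, seq_len, rot_dim]; inv_freq: [rot_dim // 2]; position_ids: [bsz, seq_len]
    theta = torch.outer(position_ids[0].float(), inv_freq.float())  # [seq_len, rot_dim // 2]
    # Interleave so each adjacent pair (x[..., 2i], x[..., 2i+1]) shares one frequency.
    cos = torch.cos(theta).repeat_interleave(2, dim=-1)             # [seq_len, rot_dim]
    sin = torch.sin(theta).repeat_interleave(2, dim=-1)
    # Rotate every two: within each pair, (x0, x1) -> (-x1, x0).
    x_rot = torch.stack((-x[..., 1::2], x[..., 0::2]), dim=-1).flatten(-2)
    return (x * cos + x_rot * sin).to(x.dtype)

The fallback branch builds the same interleaved cos/sin with torch.outer and repeat_interleave before calling apply_rotary_pos_emb, while the fused kernel rotates in place on the XPU, which is why only the first rot_dim channels of query_states and key_states are written back.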

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, hidden_states)
