Skip to content

Commit

Permalink
optimize glm edge again
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Dec 13, 2024
1 parent fa261b8 commit 21dbf79
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache


Expand Down Expand Up @@ -92,9 +93,13 @@ def glm_attention_forward(
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=1)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if query_states.device.type == "xpu":
import xe_addons
make_cache_contiguous_inplaced(cos, sin)
xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
# sin and cos are specific to RoPE models; cache_position needed for the static cache
Expand Down
8 changes: 8 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,11 @@ def get_q_proj_or_qkv_proj(self):
elif hasattr(self, "qkv_proj"):
proj = self.qkv_proj
return proj


def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor):
if not cos.is_contiguous():
new_cos = cos.contiguous()
new_sin = sin.contiguous()
cos.set_(new_cos)
sin.set_(new_sin)

0 comments on commit 21dbf79

Please sign in to comment.