From 15219944b80f0bd31e7d689965c8cf05632f0e5c Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 13 Dec 2024 13:52:39 +0800 Subject: [PATCH] optimize glm edge again (#12539) --- python/llm/src/ipex_llm/transformers/models/glm.py | 8 +++++++- python/llm/src/ipex_llm/transformers/models/utils.py | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py index 485a449d294..db9b8fd9afc 100644 --- a/python/llm/src/ipex_llm/transformers/models/glm.py +++ b/python/llm/src/ipex_llm/transformers/models/glm.py @@ -41,6 +41,7 @@ from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache from ipex_llm.transformers.models.common import merge_qkv_base from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal +from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache @@ -94,7 +95,12 @@ def glm_attention_forward( self.num_key_value_heads], dim=1) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if query_states.device.type == "xpu": + import xe_addons + make_cache_contiguous_inplaced(cos, sin) + xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) use_quantizekv = isinstance(past_key_value, DynamicFp8Cache) # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 351ce689de3..1589bea4403 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -493,3 +493,11 @@ def get_q_proj_or_qkv_proj(self): elif hasattr(self, 
def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor) -> None:
    """Make the rotary ``cos``/``sin`` caches contiguous without rebinding them.

    Each tensor that is not already contiguous is pointed, via
    ``Tensor.set_``, at a freshly materialised contiguous copy of itself.
    The Python tensor objects passed in are mutated, so every holder of
    those same objects sees the contiguous layout afterwards.

    The two tensors are checked independently: a contiguous ``cos`` must not
    cause a strided ``sin`` to be skipped (and vice versa).

    Args:
        cos: cosine cache tensor; modified in place when non-contiguous.
        sin: sine cache tensor; modified in place when non-contiguous.

    Returns:
        None. The arguments are updated in place.
    """
    for cache in (cos, sin):
        # Already-contiguous tensors are left untouched (no copy, no set_).
        if not cache.is_contiguous():
            # contiguous() allocates a dense copy; set_() then makes the
            # existing tensor object adopt that copy's storage, size and
            # strides, which is what makes the change visible to the caller.
            cache.set_(cache.contiguous())