From 21dbf794c5207d0326255d3f5ad8ceccacb28f66 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 13 Dec 2024 13:45:23 +0800 Subject: [PATCH] optimize glm edge again --- python/llm/src/ipex_llm/transformers/models/glm.py | 9 +++++++-- python/llm/src/ipex_llm/transformers/models/utils.py | 8 ++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py index 485a449d294..00d32790954 100644 --- a/python/llm/src/ipex_llm/transformers/models/glm.py +++ b/python/llm/src/ipex_llm/transformers/models/glm.py @@ -41,6 +41,7 @@ from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache from ipex_llm.transformers.models.common import merge_qkv_base from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal +from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache @@ -92,9 +93,13 @@ def glm_attention_forward( query_states, key_states, value_states = qkv.split([self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=1) - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if query_states.device.type == "xpu": + import xe_addons + make_cache_contiguous_inplaced(cos, sin) + xe_addons.rotary_two_with_cache_inplaced(query_states, key_states, cos, sin, True) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) use_quantizekv = isinstance(past_key_value, DynamicFp8Cache) # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 351ce689de3..1589bea4403 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -493,3 +493,11 @@ def get_q_proj_or_qkv_proj(self): elif hasattr(self, "qkv_proj"): proj = self.qkv_proj return proj + + +def make_cache_contiguous_inplaced(cos: torch.Tensor, sin: torch.Tensor): + if not cos.is_contiguous(): + new_cos = cos.contiguous() + new_sin = sin.contiguous() + cos.set_(new_cos) + sin.set_(new_sin)