diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py index 9069c49c0c5..b63772a8e41 100644 --- a/python/llm/src/ipex_llm/transformers/models/mixtral.py +++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py @@ -51,8 +51,7 @@ from ipex_llm.ggml.quantize import ggml_tensor_qtype from ipex_llm.utils.common import invalidInputError from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\ - apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36 +from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36 from ipex_llm.transformers.models.mistral import should_use_fuse_rope from ipex_llm.transformers.models.utils import use_decoding_fast_path from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp @@ -258,16 +257,9 @@ def mixtral_attention_forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if use_fuse_rope: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, - key_states, - sin, - cos, - "mixtral") + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 1589bea4403..904ffe9b727 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -207,36 +207,6 @@ def apply_ipex_rotate_every_two(q, k, cos, sin): torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k) -def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None): - if q.device.type != "xpu": - invalidInputError(False, - f"only xpu is supported in this function") - import xe_addons - q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device) - k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device) - if model_family in ["qwen", "mixtral"]: - xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, - q_embed, k_embed) - elif model_family in ["qwen2", "yuan", "stablelm", "qwen2_moe", "internlm"]: - cos = cos.to(q.dtype) - sin = sin.to(q.dtype) - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, - q_embed, k_embed) - elif model_family in ["gemma", "phi3"]: - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - xe_addons.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, - q_embed, k_embed) - else: - invalidInputError(False, - f"{model_family} is not supported.") - return q_embed, k_embed - - def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1): # to determinate if is enough kv cache room in transformers==4.36 # seq_len for current seq len