diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index e6a787b67b0..d98974d67ac 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -48,8 +48,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \ get_compresskv_attn_mask -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \ - apply_rotary_pos_emb_no_cache_xpu +from ipex_llm.transformers.models.utils import apply_rotary_pos_emb from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \ is_enough_kv_cache_room_4_36 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS @@ -64,7 +63,6 @@ except ImportError: Cache = Tuple[torch.Tensor] -from ipex_llm.transformers.low_bit_linear import FP6, FP16 import os @@ -274,8 +272,6 @@ def mistral_attention_forward_quantized( original_dtype = hidden_states.dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - if self.q_proj.qtype not in [FP6, FP16]: - use_fuse_rope = False enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value) decoding_fast_path = use_decoding_fast_path(self.q_proj, @@ -304,7 +300,8 @@ def mistral_attention_forward_quantized( self.q_proj.weight.qtype, self.v_proj.weight.qtype, 0, - self.head_dim) + self.head_dim, + self.rotary_emb.base) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -321,11 +318,9 @@ def mistral_attention_forward_quantized( kv_seq_len += past_key_value[0].shape[-2] if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "mistral", - self.config.rope_theta) + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -482,8 +477,6 @@ def mistral_attention_forward_original( original_dtype = hidden_states.dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - if self.q_proj.qtype not in [FP6, FP16]: - use_fuse_rope = False enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value) decoding_fast_path = use_decoding_fast_path(self.q_proj, @@ -506,7 +499,8 @@ def mistral_attention_forward_original( self.q_proj.weight.qtype, self.v_proj.weight.qtype, kv_seq_len, - self.head_dim) + self.head_dim, + self.rotary_emb.base) kv_seq_len += 1 else: @@ -542,11 +536,9 @@ def mistral_attention_forward_original( kv_seq_len += past_key_value[0].shape[-2] if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "mistral", - self.config.rope_theta) + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -708,8 +700,6 @@ def mistral_attention_forward_4_36_quantized( original_dtype = hidden_states.dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - if self.q_proj.qtype not in [FP6, FP16]: - use_fuse_rope = False enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len) @@ -739,7 +729,8 @@ def mistral_attention_forward_4_36_quantized( self.q_proj.weight.qtype, self.v_proj.weight.qtype, 0, - self.head_dim) + self.head_dim, + self.rotary_emb.base) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -765,11 +756,9 @@ def mistral_attention_forward_4_36_quantized( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "mistral", - self.config.rope_theta) + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -928,8 +917,6 @@ def mistral_attention_forward_4_36_original( use_compresskv = isinstance(past_key_value, DynamicCompressCache) use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - if self.q_proj.qtype not in [FP6, FP16]: - use_fuse_rope = False enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, @@ -958,7 +945,8 @@ def mistral_attention_forward_4_36_original( self.q_proj.weight.qtype, self.v_proj.weight.qtype, kv_seq_len, - self.head_dim) + self.head_dim, + self.rotary_emb.base) kv_seq_len += 1 # update past_key_value's seem_tokens and kv caches. @@ -1011,11 +999,9 @@ def mistral_attention_forward_4_36_original( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "mistral", - self.config.rope_theta) + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, @@ -1189,8 +1175,6 @@ def mistral_attention_forward_4_39_original( use_compresskv = isinstance(past_key_value, DynamicCompressCache) use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - if self.q_proj.qtype not in [FP6, FP16]: - use_fuse_rope = False enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len) @@ -1218,7 +1202,8 @@ def mistral_attention_forward_4_39_original( self.q_proj.weight.qtype, self.v_proj.weight.qtype, kv_seq_len, - self.head_dim) + self.head_dim, + self.rotary_emb.base) kv_seq_len += 1 # update past_key_value's seem_tokens and kv caches. @@ -1270,11 +1255,9 @@ def mistral_attention_forward_4_39_original( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if use_fuse_rope: - query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, - key_states, - position_ids, - "mistral", - self.config.rope_theta) + import xe_addons + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states, key_states) else: cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states,