diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index c02bfb3d82c..85f58f61427 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1143,8 +1143,17 @@ def llama_attention_forward_4_41_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
-                                   q_len, kv_seq_len, output_attentions):
+        if use_cache:
+            cache_kwargs = None
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+        if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                                        query_states, self.training):
+            import xe_addons
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
+        elif should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                     q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states,
                                                          attention_mask, cache_position,
@@ -1184,10 +1193,6 @@
             attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                  dtype=torch.float32).to(query_states.dtype)
             attn_output = torch.matmul(attn_weights, repeated_value_states)
-        if use_cache:
-            cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index a3ed08def23..bc66b77fc8f 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -52,7 +52,8 @@
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
+    use_sdp_causal
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -599,6 +600,15 @@ def mistral_attention_forward_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    elif use_sdp_causal(q_len, key_states.shape[2], self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
@@ -1052,6 +1062,15 @@ def mistral_attention_forward_4_36_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    elif use_sdp_causal(q_len, key_states.shape[2], self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
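For reference, the new branches all follow the same pattern: when `use_sdp_causal` accepts the shapes, they hand causal scaled-dot-product attention to the `xe_addons` XPU kernels (`sdp_fp8_causal` over the quantized KV cache in llama.py, `sdp_causal` over the fp16 KV states in mistral.py) and then reshape the result back to `(bsz, q_len, hidden_size)` exactly as the existing branches do. The sketch below is only an illustration of that computation in plain PyTorch with made-up shapes; it is not the kernel implementation and the names below are not taken from the patch.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only; not taken from the patch.
bsz, num_heads, q_len, head_dim = 1, 32, 16, 128

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Plain-PyTorch stand-in for what the new elif branches compute via
# xe_addons.sdp_causal / sdp_fp8_causal: causal scaled-dot-product
# attention over the current query/key/value states.
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, is_causal=True)

# Same reshape back to (bsz, q_len, hidden_size) as in the diff.
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 16, 4096])
```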