diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 1a71663dd9c..6bfbf460d04 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -75,7 +75,7 @@ def siglip_attention_forward(
 
     attn_weights = None
     attn_output = scaled_dot_product_attention(
-        query_states, key_states, value_states,
+        query_states, key_states.contiguous(), value_states.contiguous(),
         attention_mask, False, 1 / math.sqrt(self.head_dim)
     )
 
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 011d6c22d03..d9c04a80e0b 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -583,8 +583,7 @@ def qwen2_attention_forward(
                                                      self.layer_idx, None)
 
     attn_weights = None
-    if query_states.device.type == 'xpu' \
-            and use_flash_attention(query_states, key_states, attention_mask):
+    if use_flash_attention(query_states, key_states, attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
index b310b1d277a..71a63366835 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -43,8 +43,9 @@
 import torch
 
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError
@@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
                       "unexpected input")
 
     if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
-        import xe_addons
         image_num = len(seq_lens) - 1
         image_size = seq_lens[1] - seq_lens[0]
         guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
@@ -209,7 +209,10 @@ def qwen2_vision_attention_forward(
         v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
         # q, k, v: [image_num, num_heads, image_size, head_dim]
 
-        attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
+        attn_output = scaled_dot_product_attention(
+            q, k.contiguous(), v.contiguous(),
+            None, False
+        )
         attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
         attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
         # attn_output: [seq_length, num_heads, head_dim]
@@ -226,7 +229,10 @@ def qwen2_vision_attention_forward(
             tmp_q = q[:, :, start_idx:end_idx, :]
             tmp_k = k[:, :, start_idx:end_idx, :]
             tmp_v = v[:, :, start_idx:end_idx, :]
-            attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
+            attn_output = scaled_dot_product_attention(
+                tmp_q, tmp_k, tmp_v,
+                None, False
+            )
             attn_output = attn_output.permute(0, 2, 1, 3)
             # attn_output: [1, seq_length, num_heads, head_dim]
             attn_outputs.append(attn_output)
@@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                      self.layer_idx, None)
 
-    kv_seq_len = key_states.size(2)
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, :kv_seq_len]
-
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, causal_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if causal_mask is not None:
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == key_states.size(2)
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)
diff --git a/python/llm/src/ipex_llm/transformers/models/sd.py b/python/llm/src/ipex_llm/transformers/models/sd.py
index 4ba360b17ae..06109f15f95 100644
--- a/python/llm/src/ipex_llm/transformers/models/sd.py
+++ b/python/llm/src/ipex_llm/transformers/models/sd.py
@@ -37,8 +37,8 @@
 from typing import Optional
 
 from ipex_llm.transformers.utils import get_xpu_device_type
-from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
-from ipex_llm.transformers.models.utils import use_sdp_non_causal
+from ipex_llm.transformers.models.common import padding_qkv_hd
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 
 from diffusers.models.attention_processor import Attention
 
@@ -110,19 +110,10 @@ def __call__(
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
             # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)
-
-            if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
-                import xe_addons
-                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                                         value.contiguous(), attention_mask)
-            else:
-                scale = 1 / math.sqrt(head_dim)
-                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
-                if attention_mask is not None:
-                    attn_weights = attn_weights + attention_mask
-                attn_weights = attention_softmax(attn_weights)
-                hidden_states = torch.matmul(attn_weights, value)
-
+            hidden_states = scaled_dot_product_attention(
+                query, key.contiguous(), value.contiguous(),
+                attention_mask, False, 1 / math.sqrt(head_dim)
+            )
             hidden_states = hidden_states[:, :, :, :head_dim]
         else:
             hidden_states = torch.nn.functional.scaled_dot_product_attention(
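Note (reviewer sketch, not part of the diff): every hunk above funnels its attention computation through the unified scaled_dot_product_attention helper imported from ipex_llm.transformers.models.common. That helper's implementation is not shown in this diff; judging from the call sites, its positional signature is (query, key, value, mask, is_causal, scale=None). The pure-PyTorch stand-in below only illustrates that inferred contract: the real helper additionally dispatches to fused xe_addons kernels on XPU and handles FP8 / grouped-query KV caches, which this sketch deliberately omits.

import math
from typing import Optional

import torch


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
                                 value: torch.Tensor,
                                 mask: Optional[torch.Tensor],
                                 is_causal: bool,
                                 scale: Optional[float] = None) -> torch.Tensor:
    # Illustrative stand-in only; the signature is inferred from the call
    # sites in the diff above, and scale defaults to 1/sqrt(head_dim).
    if scale is None:
        scale = 1 / math.sqrt(query.size(-1))
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scale
    if mask is not None:
        attn_weights = attn_weights + mask
    if is_causal:
        # Lower-triangular mask; the offset keeps decode (q_len < kv_len) correct.
        q_len, kv_len = query.size(-2), key.size(-2)
        causal = torch.tril(torch.ones(q_len, kv_len, device=query.device),
                            diagonal=kv_len - q_len).bool()
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))
    # Upcast the softmax to fp32 for stability, as the removed fallback paths did.
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32)
    return torch.matmul(attn_weights.to(value.dtype), value)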