diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py index fece88384b5..75d6a0cea76 100644 --- a/python/llm/src/ipex_llm/transformers/models/common.py +++ b/python/llm/src/ipex_llm/transformers/models/common.py @@ -237,40 +237,23 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device) # compute - # import xe_addons - # if is_causal: - # if key.dtype == torch.uint8: - # attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale) - # else: - # attn_output = xe_addons.sdp_causal(query, key, value, mask, scale) - # elif seq_length != kv_length and seq_length <= 32: - # # todo: add scale support - # if key.dtype == torch.uint8: - # attn_output = xe_addons.sdp_fp8(query, key, value, mask) - # else: - # attn_output = xe_addons.sdp(query, key, value, mask) - # else: - # if key.dtype == torch.uint8: - # attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale) - # else: - # attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale) - import xe_addons if is_causal: if key.dtype == torch.uint8: - attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask) + attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale) else: - attn_output = xe_addons.sdp_causal(query, key, value, mask) + attn_output = xe_addons.sdp_causal(query, key, value, mask, scale) elif seq_length != kv_length and seq_length <= 32: + # todo: add scale support if key.dtype == torch.uint8: attn_output = xe_addons.sdp_fp8(query, key, value, mask) else: attn_output = xe_addons.sdp(query, key, value, mask) else: if key.dtype == torch.uint8: - attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask) + attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale) else: - attn_output = xe_addons.sdp_non_causal(query, key, value, mask) + attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale) return attn_output else: