From 3f0c76727ffb2a7fc79c9c296990b2fd11e5a143 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 26 Dec 2024 11:13:39 +0800 Subject: [PATCH] support passing attn_scale to sdpa --- .../ipex_llm/transformers/models/common.py | 27 ++++--------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py index fece88384b5..75d6a0cea76 100644 --- a/python/llm/src/ipex_llm/transformers/models/common.py +++ b/python/llm/src/ipex_llm/transformers/models/common.py @@ -237,40 +237,23 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device) # compute - # import xe_addons - # if is_causal: - # if key.dtype == torch.uint8: - # attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale) - # else: - # attn_output = xe_addons.sdp_causal(query, key, value, mask, scale) - # elif seq_length != kv_length and seq_length <= 32: - # # todo: add scale support - # if key.dtype == torch.uint8: - # attn_output = xe_addons.sdp_fp8(query, key, value, mask) - # else: - # attn_output = xe_addons.sdp(query, key, value, mask) - # else: - # if key.dtype == torch.uint8: - # attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale) - # else: - # attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale) - import xe_addons if is_causal: if key.dtype == torch.uint8: - attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask) + attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale) else: - attn_output = xe_addons.sdp_causal(query, key, value, mask) + attn_output = xe_addons.sdp_causal(query, key, value, mask, scale) elif seq_length != kv_length and seq_length <= 32: + # todo: add scale support if key.dtype == torch.uint8: attn_output = xe_addons.sdp_fp8(query, key, value, mask) else: attn_output = xe_addons.sdp(query, key, value, mask) else: if key.dtype == torch.uint8: - attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask) + attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale) else: - attn_output = xe_addons.sdp_non_causal(query, key, value, mask) + attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale) return attn_output else: