Skip to content

Commit

Permalink
support passing attn_scale to sdpa (#12619)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Dec 26, 2024
1 parent 40a7d2b commit a9abde0
Showing 1 changed file with 5 additions and 22 deletions.
27 changes: 5 additions & 22 deletions python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,40 +237,23 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)

# compute
# import xe_addons
# if is_causal:
# if key.dtype == torch.uint8:
# attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
# else:
# attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
# elif seq_length != kv_length and seq_length <= 32:
# # todo: add scale support
# if key.dtype == torch.uint8:
# attn_output = xe_addons.sdp_fp8(query, key, value, mask)
# else:
# attn_output = xe_addons.sdp(query, key, value, mask)
# else:
# if key.dtype == torch.uint8:
# attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
# else:
# attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)

import xe_addons
if is_causal:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_causal(query, key, value, mask)
attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
elif seq_length != kv_length and seq_length <= 32:
# todo: add scale support
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
else:
attn_output = xe_addons.sdp(query, key, value, mask)
else:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)

return attn_output
else:
Expand Down

0 comments on commit a9abde0

Please sign in to comment.