Skip to content

Commit

Permalink
fix sd1.5 (#12129)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Sep 26, 2024
1 parent a266528 commit 669ff1a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/sd15.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def __call__(
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# IPEX-LLM changes start
if head_dim in [40, 80]:
import xe_test
hidden_states = xe_test.sdp_non_causal(query, key.contiguous(),
value.contiguous(), attention_mask)
import xe_addons
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
value.contiguous(), attention_mask)
else:
scale = 1 / math.sqrt(head_dim)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
Expand Down

0 comments on commit 669ff1a

Please sign in to comment.