From 669ff1a97be11e0267e502fdbdb2d60049a73a6e Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 26 Sep 2024 17:15:16 +0800
Subject: [PATCH] fix sd1.5 (#12129)

---
 python/llm/src/ipex_llm/transformers/models/sd15.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/sd15.py b/python/llm/src/ipex_llm/transformers/models/sd15.py
index 0d8f3532c2b..ab999d40974 100644
--- a/python/llm/src/ipex_llm/transformers/models/sd15.py
+++ b/python/llm/src/ipex_llm/transformers/models/sd15.py
@@ -106,9 +106,9 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
         if head_dim in [40, 80]:
-            import xe_test
-            hidden_states = xe_test.sdp_non_causal(query, key.contiguous(),
-                                                  value.contiguous(), attention_mask)
+            import xe_addons
+            hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                                    value.contiguous(), attention_mask)
         else:
             scale = 1 / math.sqrt(head_dim)
             attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
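
Note: for context, the else branch that the hunk's last context lines begin is a standard non-fused
scaled-dot-product attention. Below is a minimal sketch of how that computation completes; the
function name sdp_non_causal_reference and the mask/softmax steps past the hunk's last line are
assumptions for illustration, not code taken from this patch.

import math
import torch

def sdp_non_causal_reference(query, key, value, attention_mask=None):
    # Hypothetical reference for the non-fused fallback path in sd15.py.
    # Inputs follow the layout noted in the patch's context comment:
    # (batch, num_heads, seq_len, head_dim).
    head_dim = query.size(-1)
    scale = 1 / math.sqrt(head_dim)
    # Pre-scaling the query matches the patched code's
    # torch.matmul(query * scale, key.transpose(-1, -2)).
    attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
    if attention_mask is not None:
        # Assumed additive mask, as is typical for attention processors.
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    # Output keeps the sdp shape: (batch, num_heads, seq_len, head_dim).
    return torch.matmul(attn_weights, value)

When head_dim is 40 or 80, the patched code instead calls the fused kernel
xe_addons.sdp_non_causal(query, key.contiguous(), value.contiguous(), attention_mask);
the fix itself is the rename from the test module xe_test to xe_addons.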