diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 983a6533e89..fc77cb89e70 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -92,8 +92,8 @@ def glm_sdpa(query, key, value, attention_mask=None, is_causal=False): context_layer = attn_output.view(query.shape) else: head_dim = query.size(-1) - attn = torch.matmul(query.to(key.dtype), - key.transpose(2, 3)) / math.sqrt(head_dim) + attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim), + key.transpose(2, 3)) if attn_bias is not None: attn += attn_bias attn = F.softmax(attn, dim=-1,