From 61cad54148193ec8aed7233f8d6f3165451956aa Mon Sep 17 00:00:00 2001
From: qiuxin2012
Date: Wed, 12 Jun 2024 20:51:06 +0800
Subject: [PATCH] Pass attention_mask and is_causal through in glm_sdpa

Collapse the two F.scaled_dot_product_attention branches in glm_sdpa into a
single call that forwards attention_mask and the caller's is_causal flag
unchanged, instead of hard-coding is_causal=True when no mask is given.
---
 .../src/ipex_llm/transformers/models/chatglm2.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 02de358b431..983a6533e89 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -62,16 +62,11 @@ def split_tensor_along_last_dim(
 
 def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
     if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
-        if attention_mask is None:
-            context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
-                                                           key,
-                                                           value,
-                                                           is_causal=True).to(key.dtype)
-        else:
-            context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
-                                                           key,
-                                                           value,
-                                                           attention_mask).to(key.dtype)
+        context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
+                                                       key,
+                                                       value,
+                                                       attention_mask,
+                                                       is_causal=is_causal).to(key.dtype)
     else:
         # attention_mask is not None only when past_key_value is not None and q_len > 1
         if attention_mask is not None:
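
Below is a minimal, runnable sketch of the consolidated branch, for
illustration only: the wrapper name sdpa_branch and the toy tensor shapes are
assumptions, not part of the patch. It relies on a documented property of
torch.nn.functional.scaled_dot_product_attention, namely that an explicit
attn_mask cannot be combined with is_causal=True (PyTorch raises an error),
so the single call assumes callers set at most one of the two.

    import torch
    import torch.nn.functional as F

    def sdpa_branch(query, key, value, attention_mask=None, is_causal=False):
        # Mirrors the patched branch: one SDPA call instead of two. The mask
        # is passed positionally as attn_mask; PyTorch throws if attn_mask is
        # set together with is_causal=True, so at most one should be given.
        return F.scaled_dot_product_attention(query.to(key.dtype),
                                              key,
                                              value,
                                              attention_mask,
                                              is_causal=is_causal).to(key.dtype)

    if __name__ == "__main__":
        q = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
        k = torch.randn(1, 2, 4, 8)
        v = torch.randn(1, 2, 4, 8)
        # No explicit mask, causal masking on: the deleted first branch
        # hard-coded is_causal=True; here the caller passes it explicitly.
        out = sdpa_branch(q, k, v, attention_mask=None, is_causal=True)
        print(out.shape)  # torch.Size([1, 2, 4, 8])

With attention_mask=None and is_causal=True the call matches the deleted
first branch, and with a mask and is_causal=False it matches the deleted
second branch, so behavior is preserved for both cases the old code handled.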