
Commit

update
qiuxin2012 committed Jun 12, 2024
1 parent ff8d157 commit 61cad54
Showing 1 changed file with 5 additions and 10 deletions.
python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -62,16 +62,11 @@ def split_tensor_along_last_dim(

 def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
     if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
-        if attention_mask is None:
-            context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
-                                                           key,
-                                                           value,
-                                                           is_causal=True).to(key.dtype)
-        else:
-            context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
-                                                           key,
-                                                           value,
-                                                           attention_mask).to(key.dtype)
+        context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
+                                                       key,
+                                                       value,
+                                                       attention_mask,
+                                                       is_causal=is_causal).to(key.dtype)
     else:
         # attention_mask is not None only when past_key_value is not None and q_len > 1
         if attention_mask is not None:
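For context on why the two branches can be merged: torch.nn.functional.scaled_dot_product_attention accepts both an attn_mask argument and an is_causal flag in a single call, but PyTorch raises a RuntimeError if a non-None mask is combined with is_causal=True, so the caller is still expected to supply at most one of the two. A minimal standalone sketch of the unified call (tensor shapes are illustrative assumptions, not taken from this commit):

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Case 1: no mask, caller requests causal masking.
out_causal = F.scaled_dot_product_attention(query, key, value,
                                            attn_mask=None, is_causal=True)

# Case 2: explicit additive float mask (all zeros here, i.e. no masking),
# with is_causal left False. Passing a non-None mask together with
# is_causal=True would raise a RuntimeError, so the merged call relies on
# the caller keeping the two options mutually exclusive.
mask = torch.zeros(batch, heads, q_len, kv_len)
out_masked = F.scaled_dot_product_attention(query, key, value,
                                            attn_mask=mask, is_causal=False)

print(out_causal.shape, out_masked.shape)  # both torch.Size([1, 2, 4, 8])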
