Commit 37fb718: fix style

qiuxin2012 committed Jun 13, 2024
1 parent 746bac5 commit 37fb718
Showing 1 changed file with 3 additions and 3 deletions.
python/llm/src/ipex_llm/transformers/models/chatglm4.py (6 changes: 3 additions & 3 deletions)
@@ -217,7 +217,7 @@ def chatglm4_attention_forward(
     kv_seq_len = key_states.shape[1]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[2]
-
+
     if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
         # use_fuse_rope, see chatglm4_model_forward
         cos, sin = rotary_pos_emb
@@ -289,8 +289,8 @@ def chatglm4_attention_forward(
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, n_head // n_kv_head)
         value_states = repeat_kv(value_states, n_head // n_kv_head)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)).to(value_states.dtype) / math.sqrt(head_dim)
+        attn_weights = torch.matmul(query_states / math.sqrt(head_dim),
+                                    key_states.transpose(2, 3)).to(value_states.dtype)
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
         attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
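For readers of the diff: the old and new statements compute the same attention scores, because dividing by the scalar math.sqrt(head_dim) commutes with the matrix product. The rewrite shortens the overlong second line, which is presumably what the "fix style" message refers to; as a side effect, scaling query_states before the matmul keeps the intermediate product smaller, which can be gentler on low-precision dtypes. Below is a minimal sketch verifying the equivalence; it is not from the repository, and the shapes and head_dim value are made up for illustration:

import math

import torch

# Illustrative shapes: (batch, n_head, seq_len, head_dim).
head_dim = 64
q = torch.randn(1, 8, 16, head_dim)
k = torch.randn(1, 8, 16, head_dim)

# Old ordering: scale the matmul result.
scale_after = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
# New ordering: scale the query first.
scale_before = torch.matmul(q / math.sqrt(head_dim), k.transpose(2, 3))

# Equal up to floating-point rounding.
assert torch.allclose(scale_after, scale_before, atol=1e-5)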
