diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 9ece100c9a9..86aeaba134f 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -217,7 +217,7 @@ def chatglm4_attention_forward(
     kv_seq_len = key_states.shape[1]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[2]
-
+
     if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
         # use_fuse_rope, see chatglm4_model_forward
         cos, sin = rotary_pos_emb
@@ -289,8 +289,8 @@ def chatglm4_attention_forward(
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, n_head // n_kv_head)
         value_states = repeat_kv(value_states, n_head // n_kv_head)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)).to(value_states.dtype) / math.sqrt(head_dim)
+        attn_weights = torch.matmul(query_states / math.sqrt(head_dim),
+                                    key_states.transpose(2, 3)).to(value_states.dtype)
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
         attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
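
For context (not part of the patch): the second hunk is mathematically a no-op, since dividing the query by sqrt(head_dim) before the matmul yields the same attention weights as dividing the product afterwards. The practical difference, presumably the motivation here, is that pre-scaling keeps the intermediate logits small, which matters in half precision where the un-scaled Q·Kᵀ product can exceed the float16 range. The sketch below is a standalone illustration with made-up shapes and values, not code from the repository; it needs a PyTorch build or device that supports float16 matmul.

```python
# Standalone illustration (hypothetical values, not from ipex-llm) of why
# scaling before the matmul is safer in float16.
import math
import torch

head_dim = 128
# Extreme but representable activations: each dot product is
# 23 * 23 * 128 = 67712, which exceeds the float16 maximum of 65504.
q = torch.full((1, 1, 4, head_dim), 23.0, dtype=torch.float16)
k = torch.full((1, 1, 4, head_dim), 23.0, dtype=torch.float16)

# Patched order: scale the query first, then multiply.
# The result (about 5985) fits comfortably in float16.
safe = torch.matmul(q / math.sqrt(head_dim), k.transpose(2, 3))

# Original order: multiply first, then scale.
# The intermediate product already overflows to inf before the division.
overflowed = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)

print(torch.isinf(safe).any().item())        # False
print(torch.isinf(overflowed).any().item())  # True
```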