diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py index 379cfb41f5b..d8b97139f9f 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py @@ -171,15 +171,14 @@ def chatglm4_attention_forward( past_key_value = None if kv_cache is None else (kv_cache[0], kv_cache[1]) - n_head = self.num_attention_heads_per_partition #32 - n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head # 2 - head_dim = self.hidden_size_per_attention_head # 128 - + n_head = self.num_attention_heads_per_partition + n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head + head_dim = self.hidden_size_per_attention_head - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states) # [bs, q_len, np * 3 * hn] -> [bsz, n_head, seq_len, head_dim] - qkv = qkv.view(bsz, q_len, n_head + 2 * n_kv_head, head_dim) - qkv = qkv.transpose(1, 2) + qkv = qkv.view(bsz, q_len, n_head + 2 * n_kv_head, head_dim) + qkv = qkv.transpose(1, 2) query_states, key_states, value_states = qkv.split([n_head, n_kv_head, @@ -217,7 +216,7 @@ def chatglm4_attention_forward( if use_cache: if past_key_value is None: past_key_value = torch.cat((key_states.unsqueeze(0).unsqueeze(0), - value_states.unsqueeze(0).unsqueeze(0)), dim=1) + value_states.unsqueeze(0).unsqueeze(0)), dim=1) else: past_key_value = (key_states, value_states) else: