diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index f7531561b64..57763d8f98c 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1011,8 +1011,10 @@ def llama_attention_forward_4_36_quantized(
 
     kv_seq_len = key_states.shape[-2]
     if len(past_key_value.key_cache) <= self.layer_idx:
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(
@@ -1038,7 +1040,7 @@ def llama_attention_forward_4_36_quantized(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, repeated_value_states)
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 241686935e5..8f3b98a8da1 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -395,6 +395,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if not hasattr(proj, "qtype"):
+        return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
         return False
     if proj.weight_type != 2:
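
Context for the llama.py hunks, as a minimal standalone sketch (it assumes the usual transformers-style repeat_kv semantics and hypothetical head counts, not the exact ipex_llm helper or model config): with grouped-query attention, key_states/value_states carry num_key_value_heads heads while query_states carries num_heads, so the KV tensors must be tiled num_key_value_groups times before the score and output matmuls have matching head dimensions.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq_len, head_dim) -> (bsz, num_kv_heads * n_rep, seq_len, head_dim)
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

# Hypothetical shapes: 32 query heads, 8 KV heads -> num_key_value_groups = 4.
bsz, q_len, head_dim = 1, 16, 128
query_states = torch.randn(bsz, 32, q_len, head_dim)
key_states = torch.randn(bsz, 8, q_len, head_dim)
value_states = torch.randn(bsz, 8, q_len, head_dim)

repeated_key_states = repeat_kv(key_states, 4)
repeated_value_states = repeat_kv(value_states, 4)

# Without the repeat, this matmul fails (32 query heads vs. 8 KV heads).
attn_weights = torch.matmul(query_states, repeated_key_states.transpose(2, 3))
attn_output = torch.matmul(torch.softmax(attn_weights, dim=-1), repeated_value_states)
print(attn_weights.shape, attn_output.shape)  # (1, 32, 16, 16) (1, 32, 16, 128)

The utils.py hunk is a plain guard: if proj has no qtype attribute (for example, a module that was never converted to one of the library's quantized linear types), fp16_fusion_check returns False and skips the fused path instead of raising AttributeError.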