diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 42cf72e353c..3d854c332f7 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -213,6 +213,7 @@ def attention(self,
             value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])
+            new_value_states = value_states
 
         query_states, key_states = self.apply_rotary_pos_emb(
             q=query_states,
@@ -225,7 +226,6 @@ def attention(self,
             head_dim=head_dim,
         )
         new_key_states = key_states
-        new_value_states = value_states
 
         if mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)
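
For context, a minimal sketch of the control flow this diff changes. The names below (`transpose_value`, `mode`, the `transpose` stub, and what the `if` branch computes) are assumptions for illustration, not the exact upstream code. The point it shows: `new_value_states` used to be reassigned unconditionally after the rotary embedding, clobbering whatever the `if` branch had prepared; the diff moves that assignment into the `else` branch, where `value_states` really is the copy to propagate.

def prepare_value_states(value_states, transpose_value, mode, transpose):
    if transpose_value:
        # Assumed: this branch builds its own new_value_states with a
        # different axis order than the else branch does.
        new_value_states = transpose(value_states, [0, 2, 3, 1])
        if mode == "decode":
            value_states = new_value_states
    else:
        value_states = transpose(value_states, [0, 2, 1, 3])
        new_value_states = value_states  # the line the diff adds

    # ... apply_rotary_pos_emb(...) runs here on query/key states ...

    # The diff removes the unconditional reassignment that used to follow:
    #     new_value_states = value_states
    # which overwrote the value prepared in the `if` branch above.
    return value_states, new_value_states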