From d2199db4ad1d107ab1e9042ce53c4a51fc156737 Mon Sep 17 00:00:00 2001
From: rnwang04
Date: Fri, 13 Dec 2024 11:20:47 +0800
Subject: [PATCH] further fix

---
 .../llm/src/ipex_llm/transformers/npu_models/mp_models_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 42cf72e353c..3d854c332f7 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -213,6 +213,7 @@ def attention(self,
             value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])
+            new_value_states = value_states
 
         query_states, key_states = self.apply_rotary_pos_emb(
             q=query_states,
@@ -225,7 +226,6 @@ def attention(self,
             head_dim=head_dim,
         )
         new_key_states = key_states
-        new_value_states = value_states
 
         if mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)