From c67d363184b8ee587bab824305cb6fd569c2e1c6 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 8 Feb 2024 17:04:59 +0800
Subject: [PATCH] quick fix qwen2 fp8 kv cache (#10135)

---
 python/llm/src/bigdl/llm/transformers/models/qwen2.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
index de9ccb61039..e71a1df6299 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
@@ -167,6 +167,8 @@ def qwen2_attention_forward_quantized(
     if q_len != 1:
         key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
+        key = repeat_kv(key, self.num_key_value_groups)
+        value = repeat_kv(value, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states, key.transpose(2, 3))
     else:
         import linear_q4_0
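
Note: for context, the added calls expand the grouped key/value heads so their head count
matches the query heads before torch.matmul. A minimal sketch of repeat_kv, assuming the
usual Hugging Face transformers semantics and tensor layout (batch, num_kv_heads, seq_len,
head_dim); this is illustrative, not the exact helper used in bigdl-llm:

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand (batch, num_kv_heads, seq_len, head_dim) to
        # (batch, num_kv_heads * n_rep, seq_len, head_dim) so key/value
        # tensors line up with the query heads in grouped-query attention.
        if n_rep == 1:
            return hidden_states
        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim)
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

Without this expansion, the fp8 path restored the cache to full precision but left key/value
with only num_key_value_heads heads, so the query-key matmul shapes did not match.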