From 841dbcdf3a6379369eee2ae63c59f9cee7537003 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Mon, 12 Aug 2024 13:53:55 +0300
Subject: [PATCH] Fix compresskv with lookahead issue (#11767)

* fix compresskv + lookahead attn_mask qwen2

* support llama chatglm

* support mistral & chatglm

* address comments

* revert run.py
---
 .../src/ipex_llm/transformers/models/chatglm2.py    |  7 ++++++-
 .../src/ipex_llm/transformers/models/chatglm4.py    | 10 ++++++++--
 .../llm/src/ipex_llm/transformers/models/llama.py   | 13 ++++++++-----
 .../llm/src/ipex_llm/transformers/models/mistral.py | 11 ++++++-----
 .../llm/src/ipex_llm/transformers/models/qwen2.py   |  4 ++--
 .../llm/src/ipex_llm/transformers/models/utils.py   |  7 +++++++
 6 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index dcf55e54de8..d8ca9c8eb87 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -108,7 +108,10 @@ def chatglm2_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(0)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(0)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
@@ -300,6 +303,8 @@ def chatglm2_attention_forward(
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, head_dim, query_states):
         import xe_addons
+        if use_compresskv and attention_mask is not None:
+            attention_mask = None
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
         else:
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 53ec5e74809..4874a8957b2 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -21,7 +21,8 @@
 from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
-    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
+    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from ipex_llm.transformers.kv import DynamicCompressCache
@@ -79,7 +80,10 @@ def chatglm4_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(2)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(2)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
@@ -232,6 +236,8 @@ def chatglm4_attention_forward(
             attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
         import xe_addons
+        if use_compresskv:
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                    attention_mask)
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index e1b2d5f11b1..ccc4afc37ff 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -42,7 +42,8 @@
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import SILU
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -1547,9 +1548,10 @@ def llama_attention_forward_4_41_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
@@ -2111,9 +2113,10 @@ def llama_attention_forward_4_38_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 358cc9ccbc1..f077474fb65 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -46,7 +46,8 @@
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -1097,9 +1098,9 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
@@ -1348,9 +1349,9 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index df39266258f..b2ec61a3222 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -48,7 +48,7 @@
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
-    should_use_compresskv, is_enough_kv_cache_room_4_36
+    should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
 from ipex_llm.utils.common import invalidInputError
@@ -473,7 +473,7 @@ def qwen2_attention_forward(
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index be19f5ac0b5..3b2ab76a393 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -497,6 +497,13 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     return x.device.type == 'xpu' and use_compress_kv == "1"
 
 
+def get_compresskv_attn_mask(key_states: torch.Tensor,
+                             attention_mask: torch.Tensor):
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, -key_states.size(2):]
+    return attention_mask
+
+
 def get_q_proj_or_qkv_proj(self):
     if hasattr(self, "q_proj"):
         proj = self.q_proj
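
Note (not part of the patch): the snippet below is a minimal, self-contained sketch of what the new get_compresskv_attn_mask helper added to utils.py does. With CompressKV the cached key/value length can be shorter than the sequence length the attention mask was built for, and with lookahead the decode step passes more than one query token, so simply setting the mask to None (the previous behavior, which appears safe only for single-token decode) no longer works; the helper instead keeps only the trailing mask columns that match the compressed KV length. The tensor shapes below are hypothetical and chosen purely for illustration.

import torch


def get_compresskv_attn_mask(key_states: torch.Tensor,
                             attention_mask: torch.Tensor):
    # Mirrors the helper added in utils.py above: keep only the mask columns
    # for the KV positions that are actually kept in the compressed cache.
    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, -key_states.size(2):]
    return attention_mask


# Hypothetical shapes: a lookahead step decodes 4 query tokens against a
# compressed cache holding 128 KV positions, while the mask still covers the
# original 256-token context.
key_states = torch.zeros(1, 32, 128, 96)    # [batch, kv_heads, kv_len, head_dim]
attention_mask = torch.zeros(1, 1, 4, 256)  # [batch, 1, q_len, seq_len]
print(get_compresskv_attn_mask(key_states, attention_mask).shape)
# torch.Size([1, 1, 4, 128])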