Fix compresskv with lookahead issue (#11767)
* fix compresskv + lookahead attn_mask qwen2

* support llama chatglm

* support mistral & chatglm

* address comments

* revert run.py
cyita authored Aug 12, 2024
1 parent f97a77e commit 841dbcd
Showing 6 changed files with 37 additions and 15 deletions.
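In short: with lookahead generation a decode step passes more than one query token together with an attention mask built for the full, uncompressed sequence, while the compressed KV cache holds fewer entries, so the fused SDP kernels see a mask whose key dimension does not match key_states. The fix trims the mask to the last key_states.size(2) columns via a new get_compresskv_attn_mask helper (chatglm2 simply drops the mask instead). A minimal sketch of the mismatch and the trim; all tensor shapes here are invented for illustration, not taken from the commit:

import torch

bsz, n_heads, head_dim = 1, 32, 128
q_len = 3             # lookahead decode: several query tokens in one step
full_kv_len = 1024    # tokens seen so far (what the mask was built for)
kept_kv_len = 512     # entries actually kept by the compressed KV cache

query_states = torch.randn(bsz, n_heads, q_len, head_dim)
key_states = torch.randn(bsz, n_heads, kept_kv_len, head_dim)
attention_mask = torch.zeros(bsz, 1, q_len, full_kv_len)

# Before the fix: mask key dim (1024) != key length (512) inside the kernel.
# After the fix: keep only the columns that line up with the kept entries.
attention_mask = attention_mask[:, :, :, -key_states.size(2):]
assert attention_mask.size(-1) == key_states.size(2)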
7 changes: 6 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -108,7 +108,10 @@ def chatglm2_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(0)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(0)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
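The same position-id fix recurs in chatglm4 below (with size(2) for its cache layout). The reason for routing through the cache object: assuming DynamicCompressCache.get_seq_length() reports the number of tokens processed rather than the number of entries kept (which is what makes the branch necessary), position ids derived from the compressed tensor's shape would jump backwards mid-sequence. A toy illustration with invented lengths:

import torch

seq_length = 3        # new tokens this step
tokens_seen = 1024    # what get_seq_length() is assumed to report
entries_kept = 512    # what past_key_values[0][0].size(0) would report

# Fixed: numbering continues from the true sequence position.
position_ids = torch.arange(tokens_seen, tokens_seen + seq_length, dtype=torch.int64)
# Pre-fix: numbering restarts from the compressed length, mid-sequence.
bad_position_ids = torch.arange(entries_kept, entries_kept + seq_length, dtype=torch.int64)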
@@ -300,6 +303,8 @@ def chatglm2_attention_forward(
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, head_dim, query_states):
         import xe_addons
+        if use_compresskv and attention_mask is not None:
+            attention_mask = None
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
         else:
10 changes: 8 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -21,7 +21,8 @@
 from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
-    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
+    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from ipex_llm.transformers.kv import DynamicCompressCache
@@ -79,7 +80,10 @@ def chatglm4_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(2)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(2)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
@@ -232,6 +236,8 @@ def chatglm4_attention_forward(
             attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
             import xe_addons
+            if use_compresskv:
+                attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
             if use_quantize_kv:
                 attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                        attention_mask)
13 changes: 8 additions & 5 deletions python/llm/src/ipex_llm/transformers/models/llama.py
@@ -42,7 +42,8 @@
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import SILU
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -1547,9 +1548,10 @@ def llama_attention_forward_4_41_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
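Note the behavioral change here (repeated in the 4.38 hunk below): previously CompressKV forced the mask to None, which is harmless when q_len == 1 but wrong under lookahead, where the q_len query rows are distinct positions and the tail of the mask still encodes causality among them. A stand-in sketch using stock PyTorch SDPA in place of xe_addons.sdp, with invented shapes:

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 3, 8)          # 3 lookahead query positions
k = v = torch.randn(1, 1, 512, 8)    # compressed KV entries

# Assume the last 3 keys correspond to the 3 queries; mask out "future"
# positions per row, as the tail of a causal mask would.
tail = torch.triu(torch.full((3, 3), float("-inf")), diagonal=1)
mask = torch.cat([torch.zeros(3, 512 - 3), tail], dim=-1).view(1, 1, 3, 512)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_none = F.scaled_dot_product_attention(q, k, v)  # old behavior: mask dropped
print(torch.allclose(out_masked, out_none))  # False: the mask still matters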
@@ -2111,9 +2113,10 @@ def llama_attention_forward_4_38_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
11 changes: 6 additions & 5 deletions python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -46,7 +46,8 @@
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -1097,9 +1098,9 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
@@ -1348,9 +1349,9 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -48,7 +48,7 @@
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
-    should_use_compresskv, is_enough_kv_cache_room_4_36
+    should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
 from ipex_llm.utils.common import invalidInputError
@@ -473,7 +473,7 @@ def qwen2_attention_forward(
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
7 changes: 7 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/utils.py
@@ -497,6 +497,13 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     return x.device.type == 'xpu' and use_compress_kv == "1"
 
 
+def get_compresskv_attn_mask(key_states: torch.Tensor,
+                             attention_mask: torch.Tensor):
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, -key_states.size(2):]
+    return attention_mask
+
+
 def get_q_proj_or_qkv_proj(self):
     if hasattr(self, "q_proj"):
         proj = self.q_proj
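A quick usage sketch for the new helper; the tensor shapes are invented for illustration:

import torch
from ipex_llm.transformers.models.utils import get_compresskv_attn_mask

key_states = torch.randn(1, 32, 512, 128)    # compressed cache: 512 kept entries
attention_mask = torch.zeros(1, 1, 3, 1024)  # mask built for the full 1024 tokens

trimmed = get_compresskv_attn_mask(key_states, attention_mask)
print(trimmed.shape)  # torch.Size([1, 1, 3, 512])

# None passes straight through, matching the kernels' no-mask path.
assert get_compresskv_attn_mask(key_states, None) is None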
