MiniCPM-V support compresskv (#11779)
* fix check error

* fix other models

* remove print
cyita authored Aug 13, 2024
1 parent 3998de1 commit 7cd6ec9
Showing 7 changed files with 9 additions and 9 deletions.
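
Every hunk below makes the same one-line change: the sequence length passed to should_use_compresskv is now read from shape[1] instead of shape[-1]. For plain input_ids the two indices agree, but multimodal models such as MiniCPM-V call these forwards with inputs_embeds, whose last dimension is the hidden size, so the old check compared against the wrong number. A minimal sketch of the difference (plain PyTorch, with illustrative shapes assumed here rather than taken from the repository):

import torch

batch, seq_len, hidden = 2, 16, 4096
input_ids = torch.randint(0, 1000, (batch, seq_len))   # [batch, seq_len]
inputs_embeds = torch.randn(batch, seq_len, hidden)    # [batch, seq_len, hidden]

# For token ids the two indices pick out the same value ...
assert input_ids.shape[-1] == input_ids.shape[1] == seq_len

# ... but for embeddings (the MiniCPM-V path) shape[-1] is the hidden size,
# so a length check built on it would see 4096 instead of 16.
assert inputs_embeds.shape[-1] == hidden
assert inputs_embeds.shape[1] == seq_len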
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -87,7 +87,7 @@ def chatglm2_model_forward(
dtype=inputs_embeds.dtype, device=inputs_embeds.device)

if use_cache:
- use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
+ use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
input_ids)
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -51,7 +51,7 @@ def chatglm4_model_forward(

if use_cache:
inputs = input_ids if input_ids is not None else inputs_embeds
- use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+ use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
inputs)
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
6 changes: 3 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/llama.py
@@ -128,7 +128,7 @@ def llama_model_forward_4_36(
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
- elif should_use_compresskv(input, input.shape[-1]):
+ elif should_use_compresskv(input, input.shape[1]):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -168,7 +168,7 @@ def llama_model_forward_4_38(
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
- elif should_use_compresskv(input, input.shape[-1]):
+ elif should_use_compresskv(input, input.shape[1]):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -209,7 +209,7 @@ def llama_model_forward_4_41(
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
- elif should_use_compresskv(input, input.shape[-1]):
+ elif should_use_compresskv(input, input.shape[1]):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -628,7 +628,7 @@ def minicpm_model_forward(
self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
- elif should_use_compresskv(input, input.shape[-1]):
+ elif should_use_compresskv(input, input.shape[1]):
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -211,7 +211,7 @@ def mistral_model_forward_4_36(
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
- elif should_use_compresskv(input_ids, input_ids.shape[-1]):
+ elif should_use_compresskv(input_ids, input_ids.shape[1]):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -258,7 +258,7 @@ def model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
- use_compress_kv = should_use_compresskv(input, input.shape[-1])
+ use_compress_kv = should_use_compresskv(input, input.shape[1])
if use_cache:
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -118,7 +118,7 @@ def qwen2_model_forward(
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads//self.config.num_key_value_heads)
)
- use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+ use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])

if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
