diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 2d5597c2084..9bc41e89ebb 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -118,7 +118,8 @@ def llama_model_forward_4_36(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -157,7 +158,8 @@ def llama_model_forward_4_38(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -197,7 +199,8 @@ def llama_model_forward_4_41(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -425,7 +428,7 @@ def llama_attention_forward_4_31(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_31_quantized
     else:
         forward_function = llama_attention_forward_4_31_original
@@ -1027,7 +1030,7 @@ def llama_attention_forward_4_41(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_41_quantized
     else:
         forward_function = llama_attention_forward_4_41_original
@@ -1566,7 +1569,7 @@ def llama_attention_forward_4_38(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_38_quantized
     else:
         forward_function = llama_attention_forward_4_38_original
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index 80efde30765..29db361c6ba 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -83,7 +83,7 @@ def minicpm_attention_forward(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = minicpm_attention_forward_quantized
     else:
         forward_function = minicpm_attention_forward_original
@@ -603,7 +603,9 @@ def minicpm_model_forward(
     from ipex_llm.transformers.kv import DynamicFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                           self.config.num_attention_heads //
+                                           self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return minicpm_model_forward_internal(
@@ -1051,7 +1053,7 @@ def minicpm_attention_forward_4_39(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = minicpm_attention_forward_quantized
     else:
         forward_function = minicpm_attention_forward_original_4_39
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 3f3aa17438d..93825592016 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -205,7 +205,8 @@ def mistral_model_forward_4_36(
     from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input_ids):
@@ -237,7 +238,7 @@ def mistral_attention_forward(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_quantized
     else:
         forward_function = mistral_attention_forward_original
@@ -654,7 +655,7 @@ def mistral_attention_forward_4_36(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_4_36_quantized
     else:
         forward_function = mistral_attention_forward_4_36_original
@@ -1110,7 +1111,7 @@ def mistral_attention_forward_4_39(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_4_36_quantized
     else:
         forward_function = mistral_attention_forward_4_39_original
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 60191943f50..4b0ad99c75c 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -114,7 +114,8 @@ def qwen2_model_forward(
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_quantize_kv = (
         self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
+                                  self.config.num_attention_heads//self.config.num_key_value_heads)
     )
 
     if use_cache:
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index c4626bc9f40..5ebdeaa6bb9 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -75,7 +75,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v
 
 
-def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
@@ -87,13 +87,13 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
         return os.environ["IPEX_LLM_LOW_MEM"] == "1"
     else:
-        return x.device.type == 'xpu' and kv_cache_device_check(x) \
+        return x.device.type == 'xpu' and kv_cache_device_check(x, kv_group) \
             and hasattr(linear, "qtype") and \
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 
 
-def kv_cache_device_check(x: torch.Tensor) -> bool:
-    return get_xpu_device_type(x) == "mtl" or \
+def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
+    return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
         ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
             1 < x.size(0) and x.size(0) <= 8)
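Note (not part of the diff): the behavioral change is concentrated in kv_cache_device_check. The new kv_group argument is the GQA ratio num_attention_heads // num_key_value_heads, so on MTL devices the quantized KV cache is now used only for models without grouped-query attention (kv_group <= 1), while the Arc/Flex batch-size condition is unchanged. Below is a minimal standalone sketch of that gate; device_type stands in for get_xpu_device_type(x), batch_size for x.size(0), and the function name is purely illustrative.

    # Sketch of the updated gate in kv_cache_device_check (names are illustrative).
    def kv_cache_device_check_sketch(device_type: str, batch_size: int, kv_group: int) -> bool:
        # MTL: only quantize the KV cache when there is no GQA (kv_group <= 1).
        if device_type == "mtl":
            return kv_group <= 1
        # Arc/Flex: unchanged, quantize only for batch sizes in (1, 8].
        if device_type in ("arc", "flex"):
            return 1 < batch_size <= 8
        return False

    # Example: a GQA model with 32 attention heads and 8 KV heads (kv_group = 4)
    # no longer takes the quantized-KV path on MTL, but a multi-batch run on Arc still does.
    assert kv_cache_device_check_sketch("mtl", batch_size=1, kv_group=32 // 8) is False
    assert kv_cache_device_check_sketch("mtl", batch_size=1, kv_group=1) is True
    assert kv_cache_device_check_sketch("arc", batch_size=4, kv_group=4) is True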