disable default quantize_kv of GQA on MTL (#11679)
* disable default quantize_kv of GQA on MTL

* fix style

* fix style

* fix style

* fix style

* fix style

* fix style
hxsz1997 authored Jul 30, 2024
1 parent c020039 commit 9b36877
Showing 5 changed files with 25 additions and 18 deletions.
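In short: use_quantize_kv_cache gains a kv_group argument (the GQA group count, num_attention_heads // num_key_value_heads), and kv_cache_device_check only keeps the default FP8 KV cache on MTL when kv_group <= 1, i.e. for plain MHA models. Below is a minimal, self-contained sketch of the new gating logic; the device lookup is stubbed out here (the shipped code uses get_xpu_device_type and the linear layer's qtype), so treat it as an illustration rather than the actual implementation.

    # Sketch of the new device check (stubbed; the real version lives in
    # python/llm/src/ipex_llm/transformers/models/utils.py).
    def kv_cache_device_check(device_type: str, batch_size: int, kv_group: int) -> bool:
        if device_type == "mtl":
            # MTL: quantize the KV cache by default only for MHA models (kv_group <= 1);
            # GQA models now fall back to the regular (unquantized) cache.
            return kv_group <= 1
        if device_type in ("arc", "flex"):
            # Arc / Flex keep the previous batch-size heuristic.
            return 1 < batch_size <= 8
        return False

    # Example: a Llama-3-style GQA setup (32 query heads, 8 KV heads) on MTL.
    kv_group = 32 // 8   # num_attention_heads // num_key_value_heads
    print(kv_cache_device_check("mtl", 1, kv_group))  # False -> unquantized KV cache
    print(kv_cache_device_check("mtl", 1, 1))         # True  -> MHA still quantizes by default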
15 changes: 9 additions & 6 deletions python/llm/src/ipex_llm/transformers/models/llama.py
@@ -118,7 +118,8 @@ def llama_model_forward_4_36(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -157,7 +158,8 @@ def llama_model_forward_4_38(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -197,7 +199,8 @@ def llama_model_forward_4_41(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -425,7 +428,7 @@ def llama_attention_forward_4_31(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_31_quantized
     else:
         forward_function = llama_attention_forward_4_31_original
@@ -1027,7 +1030,7 @@ def llama_attention_forward_4_41(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_41_quantized
     else:
         forward_function = llama_attention_forward_4_41_original
@@ -1566,7 +1569,7 @@ def llama_attention_forward_4_38(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_38_quantized
     else:
         forward_function = llama_attention_forward_4_38_original
8 changes: 5 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -83,7 +83,7 @@ def minicpm_attention_forward(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = minicpm_attention_forward_quantized
     else:
         forward_function = minicpm_attention_forward_original
@@ -603,7 +603,9 @@ def minicpm_model_forward(
     from ipex_llm.transformers.kv import DynamicFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                           self.config.num_attention_heads //
+                                           self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return minicpm_model_forward_internal(
@@ -1051,7 +1053,7 @@ def minicpm_attention_forward_4_39(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = minicpm_attention_forward_quantized
     else:
         forward_function = minicpm_attention_forward_original_4_39
9 changes: 5 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -205,7 +205,8 @@ def mistral_model_forward_4_36(
     from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input_ids):
@@ -237,7 +238,7 @@ def mistral_attention_forward(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_quantized
     else:
         forward_function = mistral_attention_forward_original
@@ -654,7 +655,7 @@ def mistral_attention_forward_4_36(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_4_36_quantized
     else:
         forward_function = mistral_attention_forward_4_36_original
@@ -1110,7 +1111,7 @@ def mistral_attention_forward_4_39(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_4_36_quantized
     else:
         forward_function = mistral_attention_forward_4_39_original
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -114,7 +114,8 @@ def qwen2_model_forward(
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_quantize_kv = (
         self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
+                                  self.config.num_attention_heads//self.config.num_key_value_heads)
     )
 
     if use_cache:
8 changes: 4 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/utils.py
@@ -75,7 +75,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v
 
 
-def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
@@ -87,13 +87,13 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
         return os.environ["IPEX_LLM_LOW_MEM"] == "1"
     else:
-        return x.device.type == 'xpu' and kv_cache_device_check(x) \
+        return x.device.type == 'xpu' and kv_cache_device_check(x, kv_group) \
             and hasattr(linear, "qtype") and \
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 
 
-def kv_cache_device_check(x: torch.Tensor) -> bool:
-    return get_xpu_device_type(x) == "mtl" or \
+def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
+    return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
         ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
          1 < x.size(0) and x.size(0) <= 8)

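For reference, the kv_group value passed by the callers above is just the ratio of query heads to KV heads. A hedged helper (not part of this commit; the helper name is illustrative) showing how that ratio falls out of a Hugging Face config:

    from transformers import LlamaConfig

    def gqa_group_count(cfg) -> int:
        # num_key_value_heads defaults to num_attention_heads (plain MHA) when unset.
        n_kv = getattr(cfg, "num_key_value_heads", None) or cfg.num_attention_heads
        return cfg.num_attention_heads // n_kv

    mha_cfg = LlamaConfig(num_attention_heads=32, num_key_value_heads=32)
    gqa_cfg = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)
    print(gqa_group_count(mha_cfg))  # 1 -> MHA, default quantized KV cache still allowed on MTL
    print(gqa_group_count(gqa_cfg))  # 4 -> GQA, default quantized KV cache now skipped on MTL

This is the same quantity exposed as self.num_key_value_groups on the Hugging Face attention modules, which is why the attention-level call sites in this commit can pass it directly.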
