diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index e514d3d3cac..f292a80cf23 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -27,6 +27,7 @@ from types import SimpleNamespace from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from ipex_llm.utils.common import invalidInputError +from ipex_llm.ggml.quantize import ggml_tensor_qtype import logging logger = logging.getLogger(__name__) import asyncio @@ -106,6 +107,29 @@ def init_pipeline_parallel(): dist.init_process_group('ccl') +def _check_quantize_kv_cache(model, idx, batch_size): + # align use_quantize_kv_cache setting for different GPU in pipeline parallel + pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \ + (os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) == "1") or \ + (os.environ.get("IPEX_LLM_LOW_MEM", None) == "1") + if model.config.model_type == "qwen" and hasattr(model.config, "visual"): + # for Qwen-VL-Chat + linear = model._modules['transformer'].h[idx].mlp.c_proj + elif model.config.model_type == "chatglm": + # for chatglm3-6b, glm-4-9b-chat + linear = model._modules['transformer'].encoder.layers[idx].self_attention.query_key_value + else: + linear = model._modules['model'].layers[idx].mlp.up_proj + pp_quantize_kv_cache = pp_quantize_kv_cache or (1 < batch_size and batch_size <= 8 and + hasattr(linear, "qtype") and + linear.qtype != ggml_tensor_qtype["fp16"] and + linear.qtype != ggml_tensor_qtype["bf16"]) + if pp_quantize_kv_cache: + os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1" + else: + os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0" + + def pipeline_parallel(model, pipeline_parallel_stages): global num_layers if hasattr(model.config, 'num_hidden_layers'): @@ -255,6 +279,7 @@ def pipeline_parallel_generate(self, _past_key_values = None bs = inputs.shape[0] output_ids = inputs.clone() + _check_quantize_kv_cache(self, layer_start, bs) step = 0 # keep track of which sequences are already finished