Skip to content

Commit

Permalink
Fix setting of use_quantize_kv_cache on different GPU in pipeline p…
Browse files Browse the repository at this point in the history
…arallel (#11516)
  • Loading branch information
plusbang authored Jul 8, 2024
1 parent 7cb09a8 commit 2524267
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from types import SimpleNamespace
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
import logging
logger = logging.getLogger(__name__)
import asyncio
Expand Down Expand Up @@ -106,6 +107,29 @@ def init_pipeline_parallel():
dist.init_process_group('ccl')


def _check_quantize_kv_cache(model, idx, batch_size):
# align use_quantize_kv_cache setting for different GPU in pipeline parallel
pp_quantize_kv_cache = (os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) == "1") or \
(os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) == "1") or \
(os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
# for Qwen-VL-Chat
linear = model._modules['transformer'].h[idx].mlp.c_proj
elif model.config.model_type == "chatglm":
# for chatglm3-6b, glm-4-9b-chat
linear = model._modules['transformer'].encoder.layers[idx].self_attention.query_key_value
else:
linear = model._modules['model'].layers[idx].mlp.up_proj
pp_quantize_kv_cache = pp_quantize_kv_cache or (1 < batch_size and batch_size <= 8 and
hasattr(linear, "qtype") and
linear.qtype != ggml_tensor_qtype["fp16"] and
linear.qtype != ggml_tensor_qtype["bf16"])
if pp_quantize_kv_cache:
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"
else:
os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "0"


def pipeline_parallel(model, pipeline_parallel_stages):
global num_layers
if hasattr(model.config, 'num_hidden_layers'):
Expand Down Expand Up @@ -255,6 +279,7 @@ def pipeline_parallel_generate(self,
_past_key_values = None
bs = inputs.shape[0]
output_ids = inputs.clone()
_check_quantize_kv_cache(self, layer_start, bs)

step = 0
# keep track of which sequences are already finished
Expand Down

0 comments on commit 2524267

Please sign in to comment.