From 33852bd23e171e924e9b47adeb6423da84540f6d Mon Sep 17 00:00:00 2001
From: SONG Ge <38711238+sgwhat@users.noreply.github.com>
Date: Tue, 28 May 2024 16:52:46 +0800
Subject: [PATCH] Refactor pipeline parallel device config (#11149)

* refactor pipeline parallel device config

* meet comments

* update example

* add warnings and update code doc
---
 .../Pipeline-Parallel-Inference/generate.py   | 23 ++-----
 python/llm/src/ipex_llm/transformers/model.py | 33 +++++++++++++++++++
 2 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
index badba4a9163..84f625a1776 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -62,27 +62,8 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
                                                  load_in_4bit=True,
                                                  optimize_model=True,
                                                  trust_remote_code=True,
-                                                 use_cache=True)
-
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // args.gpu_num
-    for i in range(args.gpu_num):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == args.gpu_num - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1): ]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
+                                                 use_cache=True,
+                                                 pipeline_parallel_stages=args.gpu_num)
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 72153f4cb3d..21a1b1ea357 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -95,6 +95,28 @@ def save_low_bit(self, *args, **kwargs):
         self.to(origin_device)
 
 
+def pipeline_parallel(model, pipeline_parallel_stages):
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // pipeline_parallel_stages
+    for i in range(pipeline_parallel_stages):
+        device_map.update({key: f'xpu:{i}' for key in
+                           model_layers[split_len * i: split_len * (i + 1)]})
+        if i == pipeline_parallel_stages - 1:
+            device_map.update({key: f'xpu:{i}' for key in
+                               model_layers[split_len * (i + 1):]})
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
+    )
+    return model
+
+
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
@@ -157,6 +179,9 @@ def from_pretrained(cls,
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
                                 Default to be False. If set to True, we will use sym_int8 for
                                 lm_head when load_in_low_bit is sym_int4 or asym_int4.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+               pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+               to run pipeline parallel inference on multiple GPUs.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -190,6 +215,7 @@ def from_pretrained(cls,
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
         torch_dtype = kwargs.pop("torch_dtype", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
 
@@ -346,6 +372,13 @@ def from_pretrained(cls,
             kwargs["embedding_qtype"] = embedding_qtype
         model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 
+        if pipeline_parallel_stages > 1:
+            if speculative:
+                invalidInputError(False,
+                                  f"Please do not set speculative=True"
+                                  f" when using pipeline_parallel_stages")
+            model = pipeline_parallel(model, pipeline_parallel_stages)
+
         if speculative:
             from .speculative import speculative_generate, clear_benchmarks,\
                 _crop_past_key_values
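
For reviewers who want to try this: the refactor reduces the example's
multi-GPU setup to a single from_pretrained argument. Below is a minimal
sketch of the new call path, not taken from the patch itself; it assumes
two Intel GPUs (xpu:0 and xpu:1), and the checkpoint path is a placeholder.

    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint

    # pipeline_parallel_stages=2 places embed_tokens plus the first half of
    # the decoder layers on xpu:0 and the remaining layers plus norm/lm_head
    # on xpu:1, mirroring the device_map built in pipeline_parallel() above.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True,
                                                 pipeline_parallel_stages=2)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Inputs start on xpu:0, where the embedding layer lives; the dispatch
    # hooks added by accelerate move activations between stages in generate().
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu:0")
    with torch.inference_mode():
        output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

Note the skip_keys=["past_key_value", "past_key_values"] passed to
dispatch_model in the patch: it stops accelerate from moving the KV cache
between devices on every decoding step.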