From d8c62e0bb03f0be84a8d05c799cd95bcb794d0f6 Mon Sep 17 00:00:00 2001 From: plusbang Date: Thu, 13 Jun 2024 00:18:21 +0800 Subject: [PATCH 1/5] initial support --- .../GPU/Pipeline-Parallel-Inference/README.md | 78 ++------ .../Pipeline-Parallel-Inference/generate.py | 50 ++--- .../run_llama2_13b_arc_2_card.sh | 30 +++ .../llm/src/ipex_llm/transformers/__init__.py | 1 + python/llm/src/ipex_llm/transformers/model.py | 32 +--- .../transformers/pipeline_parallel.py | 175 ++++++++++++++++++ 6 files changed, 243 insertions(+), 123 deletions(-) create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh create mode 100644 python/llm/src/ipex_llm/transformers/pipeline_parallel.py diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md index 1f51c5f9741..42e72cc5bea 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md @@ -5,90 +5,42 @@ This example demonstrates how to run IPEX-LLM optimized low-bit model vertically ## Requirements To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine. -> [!NOTE] -> To run IPEX-LLM on multiple Intel GPUs in pipeline parallel fashion, you will need to install **Intel® oneAPI Base Toolkit 2024.1**, which could be done through an offline installer: -> ```bash -> wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/fdc7a2bc-b7a8-47eb-8876-de6201297144/l_BaseKit_p_2024.1.0.596_offline.sh -> -> sudo sh ./l_BaseKit_p_2024.1.0.596_offline.sh -> ``` - ## Example: Run pipeline parallel inference on multiple GPUs +### 0. Prerequisites + +Please visit the [Install IPEX-LLM on Linux with Intel GPU](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html), follow [Install Intel GPU Driver](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-intel-gpu-driver) and [Install oneAPI](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-oneapi) to install GPU driver and Intel® oneAPI Base Toolkit 2024.0. + ### 1. Installation ```bash conda create -n llm python=3.11 conda activate llm - +# below command will install intel_extension_for_pytorch==2.1.10+xpu as default pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30+xpu oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -``` - -### 2. Configures OneAPI environment variables - -```bash -source /opt/intel/oneapi/setvars.sh -``` - -> [!NOTE] -> Please make sure you configure the environment variables for **Intel® oneAPI Base Toolkit's version == 2024.1.**. - -### 3 Runtime Configurations - -For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device. - -
- -For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series - -```bash -export USE_XETLA=OFF -export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 -export SYCL_CACHE_PERSISTENT=1 +pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ ``` -
+### 2. Run pipeline parallel inference on multiple GPUs -
+For optimal performance, it is recommended to set several environment variables. We provide example usage as following: -For Intel Data Center GPU Max Series +- Run Llama-2-13b-chat-hf on two Intel Arc A770 ```bash -export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so -export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 -export SYCL_CACHE_PERSISTENT=1 -export ENABLE_SDP_FUSION=1 +bash run_llama2_13b_arc_2_card.sh ``` -> [!NOTE] -> Please note that `libtcmalloc.so` can be installed by `conda install -c conda-forge -y gperftools=2.10`. -
-### 4. Running examples -``` -python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT --gpu-num GPU_NUM -``` - -Arguments info: -- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`. -- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`. -- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`. -- `--gpu-num GPU_NUM`: argument defining the number of GPU to use. It is default to be `2`. +> **Note**: You could change `NUM_GPUS` to the number of GPUs you have on your machine. #### Sample Output -##### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) +##### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) ```log Inference time: xxxx s -------------------- Prompt -------------------- -[INST] <> - -<> - -What is AI? [/INST] +Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun -------------------- Output -------------------- -[INST] <> - -<> +Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She was always asking her parents to take her on trips, but they were always too busy or too tired. -What is AI? [/INST] Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence, +One day, the little girl ``` \ No newline at end of file diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py index ae3cedb10ca..7e7736d9b8d 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py @@ -19,34 +19,18 @@ import time import argparse -from ipex_llm.transformers import AutoModelForCausalLM +from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel from transformers import AutoTokenizer -# you could tune the prompt based on your own model, -# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style -DEFAULT_SYSTEM_PROMPT = """\ -""" - -def get_prompt(message: str, chat_history: list[tuple[str, str]], - system_prompt: str) -> str: - texts = [f'[INST] <>\n{system_prompt}\n<>\n\n'] - # The first user input is _not_ stripped - do_strip = False - for user_input, response in chat_history: - user_input = user_input.strip() if do_strip else user_input - do_strip = True - texts.append(f'{user_input} [/INST] {response.strip()} [INST] ') - message = message.strip() if do_strip else message - texts.append(f'{message} [/INST]') - return ''.join(texts) +init_pipeline_parallel() if __name__ == '__main__': parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') - parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf", + parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-13b-chat-hf", help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded' ', or the path to the huggingface checkpoint folder') - parser.add_argument('--prompt', type=str, default="What is AI?", + parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun", help='Prompt to infer') parser.add_argument('--n-predict', type=int, default=32, help='Max tokens to predict') @@ -66,35 +50,27 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]], # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + local_rank = torch.distributed.get_rank() # Generate predicted tokens with torch.inference_mode(): - prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) - input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0') + input_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(f'xpu:{local_rank}') # ipex_llm model needs a warmup, then inference time can be accurate output = model.generate(input_ids, - do_sample=False, - max_new_tokens=args.n_predict) - output = model.generate(input_ids, - do_sample=False, max_new_tokens=args.n_predict) # start inference st = time.time() - # if your selected model is capable of utilizing previous key/value attentions - # to enhance decoding speed, but has `"use_cache": false` in its model config, - # it is important to set `use_cache=True` explicitly in the `generate` function - # to obtain optimal performance with IPEX-LLM INT4 optimizations output = model.generate(input_ids, - do_sample=False, max_new_tokens=args.n_predict) torch.xpu.synchronize() end = time.time() output = output.cpu() - output_str = tokenizer.decode(output[0], skip_special_tokens=True) - print(f'Inference time: {end-st} s') - print('-'*20, 'Prompt', '-'*20) - print(prompt) - print('-'*20, 'Output', '-'*20) - print(output_str) + if local_rank == args.gpu_num - 1: + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(args.prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh new file mode 100644 index 00000000000..5924aada001 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh @@ -0,0 +1,30 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source /opt/intel/oneapi/setvars.sh +export MASTER_ADDR=127.0.0.1 +export MASTER_PORT=9090 +export FI_PROVIDER=tcp +export USE_XETLA=OFF +export OMP_NUM_THREADS=6 +if [[ $KERNEL_VERSION != *"6.5"* ]]; then + export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +fi +export TORCH_LLM_ALLREDUCE=0 + +NUM_GPUS=2 # number of used GPU +CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \ + generate.py --repo-id-or-model-path 'meta-llama/Llama-2-13b-chat-hf' --gpu-num $NUM_GPUS diff --git a/python/llm/src/ipex_llm/transformers/__init__.py b/python/llm/src/ipex_llm/transformers/__init__.py index 02d51f2f712..e95e770454d 100644 --- a/python/llm/src/ipex_llm/transformers/__init__.py +++ b/python/llm/src/ipex_llm/transformers/__init__.py @@ -22,3 +22,4 @@ AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \ AutoModelForTokenClassification from .modelling_bigdl import * +from .pipeline_parallel import init_pipeline_parallel diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py index 70c2f0d9b2b..8aa002eec47 100644 --- a/python/llm/src/ipex_llm/transformers/model.py +++ b/python/llm/src/ipex_llm/transformers/model.py @@ -95,28 +95,6 @@ def save_low_bit(self, *args, **kwargs): self.to(origin_device) -def pipeline_parallel(model, pipeline_parallel_stages): - model_layers = ['model.embed_tokens'] - for i in range(model.config.num_hidden_layers): - model_layers.append(f'model.layers.{i}') - model_layers = model_layers + ['model.norm', 'lm_head'] - - device_map = {} - split_len = len(model_layers) // pipeline_parallel_stages - for i in range(pipeline_parallel_stages): - device_map.update({key: f'xpu:{i}' for key in - model_layers[split_len * i: split_len * (i + 1)]}) - if i == pipeline_parallel_stages - 1: - device_map.update({key: f'xpu:{i}' for key in - model_layers[split_len * (i + 1):]}) - - from accelerate import dispatch_model - model = dispatch_model( - model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"], - ) - return model - - def _load_pre(): from transformers import GPTJModel from ipex_llm.transformers.models.gptj import gptj_model_new_init @@ -377,8 +355,16 @@ def from_pretrained(cls, invalidInputError(False, f"Please do not set speculative=True" f" when using pipeline_parallel_stages") + invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages, + "Please make sure you've called `init_pipeline_parallel()` " + "and world size is the same as `pipeline_parallel_stages`") + from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate model = pipeline_parallel(model, pipeline_parallel_stages) - + import types + # add pipeline_parallel_generate to pretrained model dynamically + model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate, + model) + torch.distributed.barrier() if speculative: from .speculative import speculative_generate, clear_benchmarks,\ _crop_past_key_values diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py new file mode 100644 index 00000000000..3859e0e4cfd --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -0,0 +1,175 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py +# + +import torch +from torch import nn +import torch.distributed as dist +import os +from typing import Callable, List, Optional +from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList + +# patch GenerationMixin.generate +from transformers import GenerationMixin +original_generate = GenerationMixin.generate + + +class DummyLayer(nn.Module): + def __init__(self, *args): + super().__init__() + # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/ + # python/llm/src/ipex_llm/transformers/models/llama.py#L2076 + self.weight = torch.randn(1,) + + def forward(self, x): + return x + + +class Dummy_MLPLayer(nn.Module): + def __init__(self, *args): + super().__init__() + # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/ + # python/llm/src/ipex_llm/transformers/models/llama.py#L119 + self.up_proj = DummyLayer() + + def forward(self, x): + return x + + +class Dummy_DecoderLayer(nn.Module): + def __init__(self, *args): + super().__init__() + # to avoid AttributeError + self.input_layernorm = DummyLayer() + self.mlp = Dummy_MLPLayer() + + def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs): + outputs = (hidden_states,) + if use_cache: + outputs += (past_key_value,) + return outputs + + +def init_pipeline_parallel(): + import oneccl_bindings_for_pytorch + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") + dist.init_process_group('ccl') + + +def pipeline_parallel(model, pipeline_parallel_stages): + slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \ + pipeline_parallel_stages + + local_rank = dist.get_rank() + layer_start = slice_size * local_rank + layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start) + + for i in range(model.config.num_hidden_layers): + if i < layer_start or i >= layer_end: + model._modules['model'].layers[i] = Dummy_DecoderLayer() + else: + # align layer_idx and len(past_key_values), otherwise abnormal output + model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start + + if local_rank != 0: + model._modules['model'].embed_tokens = DummyLayer() + if local_rank != pipeline_parallel_stages - 1: + model._modules['model'].norm = DummyLayer() + model._modules['lm_head'] = DummyLayer() + + model.pipeline_parallel_stages = pipeline_parallel_stages + model = model.to(f'xpu:{local_rank}') + return model + + +@torch.no_grad() +def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, +): + if self.pipeline_parallel_stages > 1: + if generation_config is not None and generation_config.max_new_tokens is not None: + max_new_tokens = generation_config.max_new_tokens + else: + max_new_tokens = kwargs.get("max_new_tokens", None) + return self.pipeline_parallel_generate(inputs=inputs, + max_new_tokens=max_new_tokens,) + + return original_generate(self, + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + **kwargs) + +GenerationMixin.generate = generate + + +@torch.no_grad() +def pipeline_parallel_generate(self, + inputs: Optional[torch.Tensor] = None, + max_new_tokens: int = 32, + **kwargs): + local_rank = dist.get_rank() + pre_rank = (local_rank - 1) % self.pipeline_parallel_stages + next_rank = (local_rank + 1) % self.pipeline_parallel_stages + + _input_ids = None + _past_key_values = None + bs = inputs.shape[0] + output_ids = inputs.clone() + for i in range(max_new_tokens): + if _input_ids is None: + _input_ids = inputs + + if local_rank == 0: + outputs = self(input_ids=_input_ids, inputs_embeds=None, + past_key_values=_past_key_values, use_cache=True) + else: + inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,), + device=f'xpu:{local_rank}', dtype=torch.float32) + dist.recv(inputs_embeds, src=pre_rank) + outputs = self(input_ids=None, inputs_embeds=inputs_embeds, + past_key_values=_past_key_values, use_cache=True) + + if local_rank == self.pipeline_parallel_stages - 1: + logits = outputs.logits + next_ids = torch.argmax(logits[:, -1:, :], dim=-1) + dist.broadcast(next_ids, src=local_rank) + else: + dist.send(outputs[0], dst=next_rank) + next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64) + dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1) + + _input_ids = next_ids + output_ids = torch.cat([output_ids, next_ids], dim=-1) + _past_key_values = outputs.past_key_values + return output_ids From 1ab38c17e7a1adbf27ac0d7ea3954818470ae821 Mon Sep 17 00:00:00 2001 From: plusbang Date: Thu, 13 Jun 2024 00:31:10 +0800 Subject: [PATCH 2/5] fix code style --- python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 3859e0e4cfd..dcef0104746 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -75,7 +75,7 @@ def init_pipeline_parallel(): def pipeline_parallel(model, pipeline_parallel_stages): slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \ pipeline_parallel_stages - + local_rank = dist.get_rank() layer_start = slice_size * local_rank layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start) @@ -141,7 +141,7 @@ def pipeline_parallel_generate(self, local_rank = dist.get_rank() pre_rank = (local_rank - 1) % self.pipeline_parallel_stages next_rank = (local_rank + 1) % self.pipeline_parallel_stages - + _input_ids = None _past_key_values = None bs = inputs.shape[0] @@ -168,7 +168,7 @@ def pipeline_parallel_generate(self, dist.send(outputs[0], dst=next_rank) next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64) dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1) - + _input_ids = next_ids output_ids = torch.cat([output_ids, next_ids], dim=-1) _past_key_values = outputs.past_key_values From 32495bef38bf87b94387cb86912f19ec03991626 Mon Sep 17 00:00:00 2001 From: plusbang Date: Thu, 13 Jun 2024 00:51:18 +0800 Subject: [PATCH 3/5] fix ut --- python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index dcef0104746..a81f0abc979 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -111,7 +111,7 @@ def generate( streamer: Optional["BaseStreamer"] = None, **kwargs, ): - if self.pipeline_parallel_stages > 1: + if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1: if generation_config is not None and generation_config.max_new_tokens is not None: max_new_tokens = generation_config.max_new_tokens else: From d114b8ffcdf204478592fe71d3d0629ec1dbf000 Mon Sep 17 00:00:00 2001 From: plusbang Date: Thu, 13 Jun 2024 01:48:27 +0800 Subject: [PATCH 4/5] add 1st and rest time --- .../GPU/Pipeline-Parallel-Inference/README.md | 1 + .../Pipeline-Parallel-Inference/generate.py | 1 + .../transformers/pipeline_parallel.py | 22 ++++++++++++++++++- 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md index 42e72cc5bea..381546d59b5 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md @@ -37,6 +37,7 @@ bash run_llama2_13b_arc_2_card.sh ##### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) ```log Inference time: xxxx s +First token cost xxxx s and rest tokens cost average xxxx s -------------------- Prompt -------------------- Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun -------------------- Output -------------------- diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py index 7e7736d9b8d..5104c7010f0 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py @@ -69,6 +69,7 @@ if local_rank == args.gpu_num - 1: output_str = tokenizer.decode(output[0], skip_special_tokens=True) print(f'Inference time: {end-st} s') + print(f"First token cost {model.first_token_time:.4f} s and rest tokens cost average {model.rest_cost_mean:.4f} s") print('-'*20, 'Prompt', '-'*20) print(args.prompt) print('-'*20, 'Output', '-'*20) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index a81f0abc979..d750cc1b7e0 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -21,6 +21,8 @@ from torch import nn import torch.distributed as dist import os +import time +import numpy as np from typing import Callable, List, Optional from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList @@ -142,14 +144,23 @@ def pipeline_parallel_generate(self, pre_rank = (local_rank - 1) % self.pipeline_parallel_stages next_rank = (local_rank + 1) % self.pipeline_parallel_stages + self.first_token_time = 0 + self.next_token_time = [] + _input_ids = None _past_key_values = None bs = inputs.shape[0] output_ids = inputs.clone() - for i in range(max_new_tokens): + + step = 0 + while True: + if step >= max_new_tokens: + break + if _input_ids is None: _input_ids = inputs + tic = time.time() if local_rank == 0: outputs = self(input_ids=_input_ids, inputs_embeds=None, past_key_values=_past_key_values, use_cache=True) @@ -172,4 +183,13 @@ def pipeline_parallel_generate(self, _input_ids = next_ids output_ids = torch.cat([output_ids, next_ids], dim=-1) _past_key_values = outputs.past_key_values + toc = time.time() + if step == 0: + self.first_token_time = toc - tic + else: + self.next_token_time.append(toc - tic) + step += 1 + if self.device.type == 'xpu': + torch.xpu.synchronize() + self.rest_cost_mean = np.mean(self.next_token_time) return output_ids From ef8c138b4ba299677a02e11233589e1e69f609c2 Mon Sep 17 00:00:00 2001 From: plusbang Date: Thu, 13 Jun 2024 02:20:21 +0800 Subject: [PATCH 5/5] add verified model list --- python/llm/example/GPU/Pipeline-Parallel-Inference/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md index 381546d59b5..c1ffdd96b1e 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md @@ -5,6 +5,11 @@ This example demonstrates how to run IPEX-LLM optimized low-bit model vertically ## Requirements To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine. +## Verified Models +- [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) +- [Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) +- [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) + ## Example: Run pipeline parallel inference on multiple GPUs ### 0. Prerequisites