Refactor pipeline parallel multi-stage implementation #11286

Merged · 5 commits · Jun 13, 2024
82 changes: 20 additions & 62 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -5,90 +5,48 @@ This example demonstrates how to run IPEX-LLM optimized low-bit model vertically
## Requirements
Contributor (review comment):

> Please document which models have been verified.

Contributor Author (reply):

> I have updated the verified model list.

To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
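
As a quick sanity check that at least two XPUs are visible, here is a minimal sketch (assuming `intel_extension_for_pytorch` from the installation step below is already in place and the oneAPI environment has been sourced):

```python
# Minimal sanity check: count the visible XPU devices.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the 'xpu' backend)

num_gpus = torch.xpu.device_count()
print(f"Visible XPU devices: {num_gpus}")
assert num_gpus >= 2, "This pipeline parallel example needs at least two GPUs"
```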

> [!NOTE]
> To run IPEX-LLM on multiple Intel GPUs in pipeline parallel fashion, you will need to install **Intel® oneAPI Base Toolkit 2024.1**, which can be done through an offline installer:
> ```bash
> wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/fdc7a2bc-b7a8-47eb-8876-de6201297144/l_BaseKit_p_2024.1.0.596_offline.sh
>
> sudo sh ./l_BaseKit_p_2024.1.0.596_offline.sh
> ```
## Verified Models
- [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- [Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
- [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)

## Example: Run pipeline parallel inference on multiple GPUs

### 0. Prerequisites

Please visit the [Install IPEX-LLM on Linux with Intel GPU](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html), follow [Install Intel GPU Driver](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-intel-gpu-driver) and [Install oneAPI](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-oneapi) to install GPU driver and Intel® oneAPI Base Toolkit 2024.0.

### 1. Installation

```bash
conda create -n llm python=3.11
conda activate llm

# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30+xpu oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```

### 2. Configure oneAPI environment variables

```bash
source /opt/intel/oneapi/setvars.sh
```
### 2. Run pipeline parallel inference on multiple GPUs

> [!NOTE]
> Please make sure you configure the environment variables for **Intel® oneAPI Base Toolkit 2024.1**.
For optimal performance, it is recommended to set several environment variables. We provide example usage as follows:

### 3. Runtime Configurations

For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.

<details>

<summary>For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series</summary>
- Run Llama-2-13b-chat-hf on two Intel Arc A770

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export SYCL_CACHE_PERSISTENT=1
bash run_llama2_13b_arc_2_card.sh
```

</details>

<details>

<summary>For Intel Data Center GPU Max Series</summary>

```bash
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export SYCL_CACHE_PERSISTENT=1
export ENABLE_SDP_FUSION=1
```
> [!NOTE]
> Please note that `libtcmalloc.so` can be installed by `conda install -c conda-forge -y gperftools=2.10`.
</details>

### 4. Running examples
```bash
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT --gpu-num GPU_NUM
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'meta-llama/Llama-2-7b-chat-hf'`.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
- `--gpu-num GPU_NUM`: argument defining the number of GPUs to use. It defaults to `2`.
> **Note**: You can change `NUM_GPUS` to the number of GPUs you have on your machine.

#### Sample Output
##### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
##### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
```log
Inference time: xxxx s
First token cost xxxx s and rest tokens cost average xxxx s
-------------------- Prompt --------------------
<s>[INST] <<SYS>>

<</SYS>>

What is AI? [/INST]
Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
-------------------- Output --------------------
[INST] <<SYS>>

<</SYS>>
Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She was always asking her parents to take her on trips, but they were always too busy or too tired.

What is AI? [/INST] Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
One day, the little girl
```
51 changes: 14 additions & 37 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -19,34 +19,18 @@
import time
import argparse

from ipex_llm.transformers import AutoModelForCausalLM
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
from transformers import AutoTokenizer

# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
DEFAULT_SYSTEM_PROMPT = """\
"""

def get_prompt(message: str, chat_history: list[tuple[str, str]],
system_prompt: str) -> str:
texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
# The first user input is _not_ stripped
do_strip = False
for user_input, response in chat_history:
user_input = user_input.strip() if do_strip else user_input
do_strip = True
texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
message = message.strip() if do_strip else message
texts.append(f'{message} [/INST]')
return ''.join(texts)
init_pipeline_parallel()


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-13b-chat-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default="What is AI?",
parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
help='Prompt to infer')
parser.add_argument('--n-predict', type=int, default=32,
help='Max tokens to predict')
@@ -66,35 +50,28 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
local_rank = torch.distributed.get_rank()

# Generate predicted tokens
with torch.inference_mode():
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0')
input_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(f'xpu:{local_rank}')
# ipex_llm model needs a warmup, then inference time can be accurate
output = model.generate(input_ids,
do_sample=False,
max_new_tokens=args.n_predict)
output = model.generate(input_ids,
do_sample=False,
max_new_tokens=args.n_predict)

# start inference
st = time.time()
# if your selected model is capable of utilizing previous key/value attentions
# to enhance decoding speed, but has `"use_cache": false` in its model config,
# it is important to set `use_cache=True` explicitly in the `generate` function
# to obtain optimal performance with IPEX-LLM INT4 optimizations
output = model.generate(input_ids,
do_sample=False,
max_new_tokens=args.n_predict)
torch.xpu.synchronize()
end = time.time()
output = output.cpu()
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Inference time: {end-st} s')
print('-'*20, 'Prompt', '-'*20)
print(prompt)
print('-'*20, 'Output', '-'*20)
print(output_str)
if local_rank == args.gpu_num - 1:
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Inference time: {end-st} s')
print(f"First token cost {model.first_token_time:.4f} s and rest tokens cost average {model.rest_cost_mean:.4f} s")
print('-'*20, 'Prompt', '-'*20)
print(args.prompt)
print('-'*20, 'Output', '-'*20)
print(output_str)

30 changes: 30 additions & 0 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh
@@ -0,0 +1,30 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

source /opt/intel/oneapi/setvars.sh
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=9090
export FI_PROVIDER=tcp
export USE_XETLA=OFF
export OMP_NUM_THREADS=6
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
fi
export TORCH_LLM_ALLREDUCE=0

NUM_GPUS=2 # number of GPUs to use
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
generate.py --repo-id-or-model-path 'meta-llama/Llama-2-13b-chat-hf' --gpu-num $NUM_GPUS
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/transformers/__init__.py
@@ -22,3 +22,4 @@
AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \
AutoModelForTokenClassification
from .modelling_bigdl import *
from .pipeline_parallel import init_pipeline_parallel
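
With `init_pipeline_parallel` exported from `ipex_llm.transformers`, the intended call pattern mirrors the updated `generate.py` above. Below is a minimal sketch (the model id and the `load_in_4bit` flag are illustrative assumptions); launch it with `torchrun` so the world size matches `pipeline_parallel_stages`, as in the shell script above:

```python
# Sketch of the pipeline-parallel flow enabled by this export, mirroring generate.py above.
import torch
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
from transformers import AutoTokenizer

init_pipeline_parallel()  # sets up torch.distributed for the pipeline stages

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",  # illustrative model id
    load_in_4bit=True,                 # low-bit optimization flag (assumed)
    pipeline_parallel_stages=2,        # must equal the torch.distributed world size
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf",
                                          trust_remote_code=True)

local_rank = torch.distributed.get_rank()
input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to(f"xpu:{local_rank}")
with torch.inference_mode():
    output = model.generate(input_ids, do_sample=False, max_new_tokens=32)
# As in generate.py above, decoding and printing would be done on the last rank only.
```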
32 changes: 9 additions & 23 deletions python/llm/src/ipex_llm/transformers/model.py
@@ -95,28 +95,6 @@ def save_low_bit(self, *args, **kwargs):
self.to(origin_device)


def pipeline_parallel(model, pipeline_parallel_stages):
model_layers = ['model.embed_tokens']
for i in range(model.config.num_hidden_layers):
model_layers.append(f'model.layers.{i}')
model_layers = model_layers + ['model.norm', 'lm_head']

device_map = {}
split_len = len(model_layers) // pipeline_parallel_stages
for i in range(pipeline_parallel_stages):
device_map.update({key: f'xpu:{i}' for key in
model_layers[split_len * i: split_len * (i + 1)]})
if i == pipeline_parallel_stages - 1:
device_map.update({key: f'xpu:{i}' for key in
model_layers[split_len * (i + 1):]})

from accelerate import dispatch_model
model = dispatch_model(
model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
)
return model
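
For context, the helper removed here (and now provided by `pipeline_parallel.py`, per the import added below) slices the flat layer list evenly across stages, with the last stage absorbing any remainder, before handing the map to `accelerate.dispatch_model`. A standalone sketch of that split for a hypothetical 5-layer config over 2 stages:

```python
# Standalone sketch of the device-map split performed above
# (hypothetical model with 5 hidden layers, 2 pipeline stages).
num_hidden_layers = 5
stages = 2
model_layers = (['model.embed_tokens']
                + [f'model.layers.{i}' for i in range(num_hidden_layers)]
                + ['model.norm', 'lm_head'])  # 8 entries in total

device_map = {}
split_len = len(model_layers) // stages       # 4 entries per stage
for i in range(stages):
    device_map.update({key: f'xpu:{i}' for key in
                       model_layers[split_len * i: split_len * (i + 1)]})
    if i == stages - 1:                       # last stage also takes any leftover entries
        device_map.update({key: f'xpu:{i}' for key in
                           model_layers[split_len * (i + 1):]})

# Result: embed_tokens and layers 0-2 on xpu:0; layers 3-4, norm and lm_head on xpu:1.
```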


def _load_pre():
from transformers import GPTJModel
from ipex_llm.transformers.models.gptj import gptj_model_new_init
@@ -377,8 +355,16 @@ def from_pretrained(cls,
invalidInputError(False,
f"Please do not set speculative=True"
f" when using pipeline_parallel_stages")
invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages,
"Please make sure you've called `init_pipeline_parallel()` "
"and world size is the same as `pipeline_parallel_stages`")
from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
model = pipeline_parallel(model, pipeline_parallel_stages)

import types
# add pipeline_parallel_generate to pretrained model dynamically
model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
model)
torch.distributed.barrier()
if speculative:
from .speculative import speculative_generate, clear_benchmarks,\
_crop_past_key_values