Refactor pipeline parallel device config (#11149)
* refactor pipeline parallel device config

* address review comments

* update example

* add warnings and update code doc
sgwhat authored May 28, 2024
1 parent 62b2d8a commit 33852bd
Showing 2 changed files with 35 additions and 21 deletions.
23 changes: 2 additions & 21 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -62,27 +62,8 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
                                              load_in_4bit=True,
                                              optimize_model=True,
                                              trust_remote_code=True,
-                                             use_cache=True)
-
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // args.gpu_num
-    for i in range(args.gpu_num):
-        device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * i: split_len * (i + 1)]})
-        if i == args.gpu_num - 1:
-            device_map.update({key: f'xpu:{i}' for key in model_layers[split_len * (i + 1):]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model,
-        device_map=device_map,
-        offload_dir=None,
-        skip_keys=["past_key_value", "past_key_values"],
-    )
+                                             use_cache=True,
+                                             pipeline_parallel_stages=args.gpu_num)
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
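With this refactor the example no longer builds a device_map by hand; the whole partition is requested through one keyword argument. A minimal usage sketch of the new API follows — the model path and the stage count of 2 are illustrative assumptions, not part of the commit:

from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; use your local model path

# pipeline_parallel_stages > 1 shards embed_tokens, the decoder layers,
# the final norm and lm_head across that many Intel XPUs.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True,
                                             pipeline_parallel_stages=2)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)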
33 changes: 33 additions & 0 deletions python/llm/src/ipex_llm/transformers/model.py
@@ -95,6 +95,28 @@ def save_low_bit(self, *args, **kwargs):
     self.to(origin_device)
 
 
+def pipeline_parallel(model, pipeline_parallel_stages):
+    model_layers = ['model.embed_tokens']
+    for i in range(model.config.num_hidden_layers):
+        model_layers.append(f'model.layers.{i}')
+    model_layers = model_layers + ['model.norm', 'lm_head']
+
+    device_map = {}
+    split_len = len(model_layers) // pipeline_parallel_stages
+    for i in range(pipeline_parallel_stages):
+        device_map.update({key: f'xpu:{i}' for key in
+                           model_layers[split_len * i: split_len * (i + 1)]})
+        if i == pipeline_parallel_stages - 1:
+            device_map.update({key: f'xpu:{i}' for key in
+                               model_layers[split_len * (i + 1):]})
+
+    from accelerate import dispatch_model
+    model = dispatch_model(
+        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
+    )
+    return model
+
+
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
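The split above is plain floor division, with the last stage also absorbing the remainder. A self-contained sketch of the same partition arithmetic (toy numbers; no ipex-llm or accelerate required) shows the device_map it produces for a 32-layer model on 4 GPUs:

# Reproduces the partition logic of pipeline_parallel() in isolation.
num_hidden_layers = 32              # e.g. a Llama-2-7B-sized model
pipeline_parallel_stages = 4

model_layers = ['model.embed_tokens']
for i in range(num_hidden_layers):
    model_layers.append(f'model.layers.{i}')
model_layers = model_layers + ['model.norm', 'lm_head']    # 35 entries in total

device_map = {}
split_len = len(model_layers) // pipeline_parallel_stages  # 35 // 4 == 8
for i in range(pipeline_parallel_stages):
    device_map.update({key: f'xpu:{i}' for key in
                       model_layers[split_len * i: split_len * (i + 1)]})
    if i == pipeline_parallel_stages - 1:
        # The last stage also takes the leftover entries: xpu:3 ends up
        # with 11 modules while xpu:0, xpu:1 and xpu:2 hold 8 each.
        device_map.update({key: f'xpu:{i}' for key in
                           model_layers[split_len * (i + 1):]})

assert len(device_map) == 35
assert sum(v == 'xpu:3' for v in device_map.values()) == 11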
@@ -157,6 +179,9 @@ def from_pretrained(cls,
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
                 Default to be False. If set to True, we will use sym_int8 for lm_head when
                 load_in_low_bit is sym_int4 or asym_int4.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+                pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+                to run pipeline parallel inference on multiple GPUs.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -190,6 +215,7 @@ def from_pretrained(cls,
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
         torch_dtype = kwargs.pop("torch_dtype", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
 
@@ -346,6 +372,13 @@ def from_pretrained(cls,
             kwargs["embedding_qtype"] = embedding_qtype
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 
+        if pipeline_parallel_stages > 1:
+            if speculative:
+                invalidInputError(False,
+                                  f"Please do not set speculative=True"
+                                  f" when using pipeline_parallel_stages")
+            model = pipeline_parallel(model, pipeline_parallel_stages)
+
         if speculative:
             from .speculative import speculative_generate, clear_benchmarks,\
                 _crop_past_key_values
