Add tensor parallel for vLLM (#10879)
* initial

* test initial tp

* initial sup

* fix format

* fix

* fix
gc-fu authored Apr 26, 2024
1 parent d058f2b commit 990535b
Showing 4 changed files with 507 additions and 10 deletions.
10 changes: 10 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -117,6 +117,10 @@ def is_linear_module(module):
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
+        from vllm.model_executor.parallel_utils.parallel_state import (
+            get_tensor_model_parallel_group,
+            get_tensor_model_parallel_world_size
+        )
         VLLM_LINEAR_LIST = [
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         ]
@@ -125,6 +129,12 @@ def is_linear_module(module):
             out_features = module.output_size
             result = True
             mp_group = None
+            tp_size = get_tensor_model_parallel_world_size()
+            if isinstance(module, RowParallelLinear) and tp_size >= 2:
+                mp_group = get_tensor_model_parallel_group()
+                in_features = module.input_size_per_partition
+            elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
+                out_features = module.output_size_per_partition
         else:
             result = False
     elif is_gptq_linear(module):
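For readers unfamiliar with vLLM's parallel layers, the two isinstance branches above follow how tensor parallelism shards a linear layer: RowParallelLinear splits the input dimension across ranks (each rank holds input_size_per_partition, and the partial outputs must later be all-reduced over the tensor-parallel group, hence mp_group is set), while ColumnParallelLinear splits the output dimension (output_size_per_partition) and needs no reduction. A minimal sketch with hypothetical sizes, purely for illustration:

```python
# Hypothetical numbers, not taken from this commit; they only illustrate why the
# branch above swaps in the *_per_partition sizes once tp_size >= 2.
tp_size = 2                      # assumed tensor-parallel world size
full_in, full_out = 4096, 11008  # assumed full (unsharded) layer dimensions

# RowParallelLinear shards the input dimension, so each rank's weight shard
# expects input_size_per_partition features and yields a partial result that
# is later summed with an all-reduce over mp_group.
row_input_size_per_partition = full_in // tp_size    # 2048

# ColumnParallelLinear shards the output dimension, so each rank produces
# output_size_per_partition of the full output and no reduction is needed here.
col_output_size_per_partition = full_out // tp_size  # 5504
```

With tp_size < 2 there is nothing to shard, so the code keeps the full input_size/output_size and leaves mp_group as None.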
31 changes: 25 additions & 6 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -45,13 +45,15 @@
 from ipex_llm.utils.common import invalidInputError
 import os
 import torch
+import torch.distributed
 import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
     get_ipex_version
+from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available

 T = TypeVar("T", bound="torch.nn.Module")

@@ -702,8 +704,14 @@ def forward(self, x: torch.Tensor):
                 torch.xpu.empty_cache()
             result = result.view(new_shape)
             if self.mp_group is not None:
-                from deepspeed import comm as dist
-                dist.inference_all_reduce(result, group=self.mp_group)
+                # FIXME: the user may install both vllm and deepspeed
+                if is_deepspeed_available():
+                    from deepspeed import comm as dist
+                    dist.inference_all_reduce(result, group=self.mp_group)
+                elif is_vllm_available():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                else:
+                    invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
                 result += self.bias
         else:
@@ -729,6 +737,7 @@ def forward(self, x: torch.Tensor):
             result = result.view(new_shape)
             # allreduce to combine partial results and add bias if necessary
             if self.mp_group is not None:
+                # TODO: implement for CPU logic for vLLM tp
                 # deepspeed distibuted mode
                 from deepspeed import comm as dist
                 dist.inference_all_reduce(result, group=self.mp_group)
@@ -780,8 +789,13 @@ def forward(self, x: torch.Tensor):
                 self.weight_type = 2
             result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
-                from deepspeed import comm as dist
-                dist.inference_all_reduce(result, group=self.mp_group)
+                if is_deepspeed_available():
+                    from deepspeed import comm as dist
+                    dist.inference_all_reduce(result, group=self.mp_group)
+                elif is_vllm_available():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                else:
+                    invalidInputError(False, "mp_group is not None, but no supported backend found")
             return result
         else:
             if self.in_len == 4096 and self.weight_type != 3 or \
@@ -817,8 +831,13 @@ def forward(self, x: torch.Tensor):
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.mp_group is not None:
-                from deepspeed import comm as dist
-                dist.inference_all_reduce(result, group=self.mp_group)
+                if is_deepspeed_available():
+                    from deepspeed import comm as dist
+                    dist.inference_all_reduce(result, group=self.mp_group)
+                elif is_vllm_available():
+                    torch.distributed.all_reduce(result, group=self.mp_group)
+                else:
+                    invalidInputError(False, "mp_group is not None, but no supported backend found")
             if self.bias is not None:
                 result += self.bias

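The deepspeed-or-vLLM dispatch above is repeated in three forward paths of low_bit_linear.py. One way to read it is as a single backend-selection helper around the all-reduce; a minimal sketch of that idea (the helper itself is not part of this commit, only an illustration of the logic):

```python
import torch
import torch.distributed

from ipex_llm.transformers.convert import is_deepspeed_available, is_vllm_available
from ipex_llm.utils.common import invalidInputError


def _all_reduce_partial_result(result: torch.Tensor, mp_group) -> None:
    # Hypothetical refactor of the repeated branch above, not code from this commit.
    if is_deepspeed_available():
        # Prefer DeepSpeed's inference all-reduce when DeepSpeed is installed.
        from deepspeed import comm as dist
        dist.inference_all_reduce(result, group=mp_group)
    elif is_vllm_available():
        # vLLM's tensor-parallel group is a regular torch.distributed ProcessGroup,
        # so an in-place SUM all-reduce combines the RowParallelLinear partial results.
        torch.distributed.all_reduce(result, group=mp_group)
    else:
        invalidInputError(False,
                          "mp_group is not None, but no supported backend found")
```

As the FIXME in the diff notes, a user may have both packages installed; in that case DeepSpeed wins because its branch is checked first.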
10 changes: 6 additions & 4 deletions python/llm/src/ipex_llm/vllm/engine/engine.py
@@ -45,8 +45,9 @@ def from_engine_args(
         parallel_config = engine_configs[2]
         if parallel_config.worker_use_ray or engine_args.engine_use_ray:
             initialize_ray_cluster(parallel_config)
-            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
-            executor_class = RayGPUExecutorAsync
+            # from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
+            from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class_async
+            executor_class = get_gpu_executor_class_async(load_in_low_bit)
         else:
             invalidInputError(parallel_config.world_size == 1, (
                 "Ray is required if parallel_config.world_size > 1."))
@@ -130,8 +131,9 @@ def from_engine_args(
         # Initialize the cluster and specify the executor class.
         if parallel_config.worker_use_ray:
             initialize_ray_cluster(parallel_config)
-            from vllm.executor.ray_gpu_executor import RayGPUExecutor
-            executor_class = RayGPUExecutor
+            # from vllm.executor.ray_gpu_executor import RayGPUExecutor
+            from ipex_llm.vllm.ipex_llm_gpu_executor import get_gpu_executor_class
+            executor_class = get_gpu_executor_class(load_in_low_bit)
         else:
             invalidInputError(parallel_config.world_size == 1,
                               "Ray is required if parallel_config.world_size > 1.")
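In both the async and sync engines the change is the same: when Ray is in use, the stock vLLM RayGPUExecutorAsync/RayGPUExecutor is swapped for an ipex-llm executor class obtained from ipex_llm_gpu_executor, so the load_in_low_bit setting reaches the tensor-parallel workers. A rough sketch of that selection as a standalone function (the wrapper and its parameters are assumptions for illustration; only the two imported factories come from the diff):

```python
# Hypothetical wrapper, not part of this commit: everything except the two
# imported factory functions and the load_in_low_bit argument is assumed.
from ipex_llm.vllm.ipex_llm_gpu_executor import (
    get_gpu_executor_class,
    get_gpu_executor_class_async,
)


def select_executor_class(worker_use_ray: bool, use_async: bool, load_in_low_bit: str):
    """Mirror the branches above: Ray-backed executors are only needed when the
    model is sharded across multiple workers; otherwise vLLM's single-GPU path
    applies and parallel_config.world_size must be 1."""
    if not worker_use_ray:
        return None  # caller falls back to the non-Ray, single-GPU executor
    if use_async:
        return get_gpu_executor_class_async(load_in_low_bit)
    return get_gpu_executor_class(load_in_low_bit)
```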