Squash njhill/ray-optional-todo2
tjohnson31415 committed May 13, 2024
1 parent 5a0f470 commit 9300e7e
Showing 11 changed files with 238 additions and 27 deletions.
1 change: 0 additions & 1 deletion requirements-cuda.txt
@@ -2,7 +2,6 @@
-r requirements-common.txt

# Dependencies for NVIDIA GPUs
ray >= 2.9
nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.3.0
3 changes: 1 addition & 2 deletions requirements-rocm.txt
@@ -1,5 +1,4 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for AMD GPUs
ray == 2.9.3
# No specific dependencies currently for AMD GPUs
20 changes: 16 additions & 4 deletions setup.py
@@ -6,7 +6,7 @@
import subprocess
import sys
from shutil import which
from typing import Dict, List
from typing import Dict, List, Optional

import torch
from packaging.version import Version, parse
@@ -376,6 +376,20 @@ def _read_requirements(filename: str) -> List[str]:
return requirements


def get_extra_requirements() -> Optional[Dict[str, List[str]]]:
extras = {"tensorizer": ["tensorizer==2.9.0"]}
if _is_cuda():
extras["ray"] = ["ray>=2.9"]
elif _is_hip():
extras["ray"] = ["ray==2.9.3"]
elif _is_neuron() or _is_cpu():
pass
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCM or Neuron.")
return extras


ext_modules = []

if _is_cuda():
@@ -421,9 +435,7 @@ def _read_requirements(filename: str) -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer==2.9.0"],
},
extras_require=get_extra_requirements(),
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,
)
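
With Ray dropped from the base requirements and exposed as an extra, a source install that still wants the Ray backend would presumably pull it in by name, e.g. pip install "vllm[ray]". For reference, a sketch of the extras table get_extra_requirements() would return on a CUDA build (illustrative only, mirroring the branches above):

extras_require = {
    "tensorizer": ["tensorizer==2.9.0"],
    "ray": ["ray>=2.9"],  # "ray==2.9.3" on ROCm; no "ray" extra on Neuron/CPU
}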
9 changes: 5 additions & 4 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -26,18 +26,18 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("worker_use_ray", [False, True])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
enforce_eager = False
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER":
enforce_eager = True
enforce_eager = backend_by_env_var == "FLASHINFER"

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
@@ -46,7 +46,8 @@ def test_models(
vllm_model = vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager)
enforce_eager=enforce_eager,
worker_use_ray=worker_use_ray)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

3 changes: 3 additions & 0 deletions tests/distributed/test_chunked_prefill_distributed.py
@@ -27,6 +27,7 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
@pytest.mark.parametrize("worker_use_ray", [False, True])
def test_models(
hf_runner,
vllm_runner,
@@ -35,6 +36,7 @@ def test_models(
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None:
# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
@@ -53,6 +55,7 @@
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
worker_use_ray=worker_use_ray,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
13 changes: 8 additions & 5 deletions vllm/config.py
@@ -513,9 +513,9 @@ class ParallelConfig:
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
worker_use_ray: Whether to use Ray for model workers. Will default to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
greater than 1 and Ray is installed.
max_parallel_loading_workers: Maximum number of multiple batches
when load model sequentially. To avoid RAM OOM when using tensor
parallel and large models.
@@ -531,7 +531,7 @@ def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
worker_use_ray: Optional[bool] = None,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
@@ -548,8 +548,11 @@ def __init__(
self.placement_group = placement_group

self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
if self.worker_use_ray is None:
from vllm.executor import ray_utils
ray_found = ray_utils.ray is not None
self.worker_use_ray = ray_found and self.world_size > 1

self._verify_args()

def _verify_args(self) -> None:
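
The new default resolution, as a standalone sketch (hypothetical helper, not part of the diff; it uses a plain import check where ParallelConfig consults vllm.executor.ray_utils): an explicit True/False passed by the caller is honored, and only a None default falls back to "Ray is importable and more than one worker is needed".

from typing import Optional


def resolve_worker_use_ray(worker_use_ray: Optional[bool],
                           world_size: int) -> bool:
    # Mirrors ParallelConfig's fallback, for illustration only.
    if worker_use_ray is not None:
        return worker_use_ray  # an explicit value always wins
    try:
        import ray  # noqa: F401
        ray_found = True
    except ImportError:
        ray_found = False
    return ray_found and world_size > 1


print(resolve_worker_use_ray(None, 2))   # True only if Ray is installed
print(resolve_worker_use_ray(False, 4))  # False: explicit opt-out wins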
60 changes: 55 additions & 5 deletions vllm/engine/arg_utils.py
@@ -33,7 +33,7 @@ class EngineArgs:
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
worker_use_ray: Optional[bool] = None
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
@@ -220,10 +220,12 @@ def add_cli_args(
' Can be overridden per request via guided_decoding_backend'
' parameter.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
help='Use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU.')
parser.add_argument(
'--worker-use-ray',
action=BooleanOptionalAction,
default=None,
help='Use Ray for distributed serving, will default '
'to true when ray is installed and more than 1 GPU is used.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
@@ -647,3 +649,51 @@ def _engine_args_parser():
def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
async_args_only=True)


class BooleanOptionalAction(argparse.Action):
"""This class is from python 3.9, included here to retain compatibility
with python 3.8. It allows boolean args with no explicit default, so
that we can honor explicit true/false but otherwise default based on some
other criteria.
See https://docs.python.org/3/library/argparse.html#action
It can be removed once 3.8 is no longer supported.
"""

def __init__(self,
option_strings,
dest,
default=None,
type=None,
choices=None,
required=False,
help=None,
metavar=None):

_option_strings = []
for option_string in option_strings:
_option_strings.append(option_string)

if option_string.startswith('--'):
option_string = '--no-' + option_string[2:]
_option_strings.append(option_string)

super().__init__(option_strings=_option_strings,
dest=dest,
nargs=0,
default=default,
type=type,
choices=choices,
required=required,
help=help,
metavar=metavar)

def __call__(self, parser, namespace, values, option_string=None):
if option_string in self.option_strings:
setattr(namespace, self.dest,
not option_string.startswith('--no-'))

def format_usage(self):
return ' | '.join(self.option_strings)
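
A quick usage sketch of the backported action (hypothetical parser, separate from EngineArgs, assuming the BooleanOptionalAction defined above is in scope). It shows the three states the flag can now take: unset stays None so ParallelConfig auto-detects, while --worker-use-ray and --no-worker-use-ray force True and False.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--worker-use-ray',
                    action=BooleanOptionalAction,
                    default=None)

# No flag given: stays None, deferring to Ray availability and world size.
print(parser.parse_args([]).worker_use_ray)                        # None
# Explicit opt-in / opt-out override the auto-detection.
print(parser.parse_args(['--worker-use-ray']).worker_use_ray)      # True
print(parser.parse_args(['--no-worker-use-ray']).worker_use_ray)   # False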
6 changes: 4 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -357,9 +357,11 @@ def from_engine_args(
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif engine_config.parallel_config.world_size > 1:
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
6 changes: 4 additions & 2 deletions vllm/engine/llm_engine.py
@@ -282,9 +282,11 @@ def from_engine_args(
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif engine_config.parallel_config.world_size > 1:
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
executor_class = MultiprocessingGPUExecutor
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor

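
The practical effect of the new branch, sketched below: with tensor parallelism enabled and Ray not installed (and worker_use_ray left unset), the engine now selects the multiprocessing executor instead of hitting the old assertion. Model name, prompt, and GPU count are illustrative only.

from vllm import LLM, SamplingParams

# Two local GPUs, no Ray installed: worker_use_ray defaults to None, so the
# MultiprocessingGPUExecutor is chosen for the tensor-parallel workers.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)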
140 changes: 140 additions & 0 deletions vllm/executor/multiproc_gpu_executor.py
@@ -0,0 +1,140 @@
import asyncio
import os
from functools import partial
from typing import Any, Dict, Optional, Tuple

from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)

logger = init_logger(__name__)


class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor"""

def _init_executor(self) -> None:
assert (
not self.speculative_config
), "Speculative decoding not yet supported for MultiProcGPU backend."

# Create the parallel GPU workers.
world_size = self.parallel_config.tensor_parallel_size

# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
map(str, range(world_size))))

# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

from torch.cuda import device_count
assert world_size <= device_count(), (
"please set tensor_parallel_size to less than max local gpu count")

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())

if world_size == 1:
self.workers = []
else:
result_handler = ResultHandler()
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)) for rank in range(1, world_size)
]

self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
self.worker_monitor.start()

self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
worker_monitor.close()

def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""

if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")

# Start the workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]

if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs

# Start the driver worker after all the ray workers.
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*driver_args,
**driver_kwargs)

# Get the results of the workers.
return [driver_worker_output
] + [output.get() for output in worker_outputs]

def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
if not self.worker_monitor.is_alive():
raise RuntimeError("Worker processes are not running")


class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):

async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs

driver_executor = make_async(getattr(self.driver_worker, method))

# Run all the workers asynchronously.
coros = [driver_executor(*driver_args, **driver_kwargs)] + [
worker.execute_method_async(method, *args, **kwargs)
for worker in self.workers
]

return await asyncio.gather(*coros)
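
The fan-out pattern used by _run_workers, as a standalone sketch (plain concurrent.futures, not vLLM code): the non-driver workers are started first, the driver runs the same method in-process, and the workers' results are collected afterwards with the driver's result placed first.

from concurrent.futures import ProcessPoolExecutor


def run_method(rank: int) -> int:
    # Stand-in for a worker method such as init_device or load_model.
    return rank * 10


if __name__ == "__main__":
    world_size = 4
    with ProcessPoolExecutor(max_workers=world_size - 1) as pool:
        # Start the non-driver workers first (ranks 1..world_size-1).
        worker_outputs = [pool.submit(run_method, rank)
                          for rank in range(1, world_size)]
        # The driver (rank 0) executes the same method in the current process.
        driver_output = run_method(0)
        # Collect results: driver output first, then the workers'.
        results = [driver_output] + [f.result() for f in worker_outputs]
    print(results)  # [0, 10, 20, 30]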