[core] clean up executor class hierarchy between v1 and v0 (vllm-project#12171)

Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 18, 2025
1 parent 02798ec commit 6d0e3d3
Showing 6 changed files with 61 additions and 798 deletions.
10 changes: 0 additions & 10 deletions vllm/executor/executor_base.py
@@ -79,16 +79,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
        b = min([r[1] for r in results])
        return a, b

-    def initialize(self, num_gpu_blocks: int) -> None:
-        """
-        Initialize the KV caches and begin the model execution loop of the
-        underlying workers.
-        For V1 compatibility.
-        """
-        logger.info("# GPU blocks: %d", num_gpu_blocks)
-        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
-        self.collective_rpc("compile_or_warm_up_model")
-
    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
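The v1-only `initialize` removed above now lives on the v1 `Executor` (see the next file). The surviving context also shows the executor's fan-out-and-reduce idiom: broadcast an RPC to every worker, then reduce with `min` so the shared KV cache fits the most constrained worker. Below is a minimal, self-contained sketch of that idiom; the `Toy*` classes are illustrative stand-ins, not vLLM's implementation.

```python
from typing import Any, List, Tuple


class ToyWorker:
    """Stands in for a per-GPU worker process."""

    def __init__(self, gpu_blocks: int, cpu_blocks: int) -> None:
        self.gpu_blocks = gpu_blocks
        self.cpu_blocks = cpu_blocks

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return self.gpu_blocks, self.cpu_blocks


class ToyExecutor:
    def __init__(self, workers: List[ToyWorker]) -> None:
        self.workers = workers

    def collective_rpc(self, method: str, args: tuple = ()) -> List[Any]:
        # Broadcast the call to every worker and gather one result each.
        return [getattr(w, method)(*args) for w in self.workers]

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        results = self.collective_rpc("determine_num_available_blocks")
        # A centralized controller can only use capacity that exists on
        # every worker, hence the min-reduction over the gathered tuples.
        return (min(r[0] for r in results), min(r[1] for r in results))


executor = ToyExecutor([ToyWorker(1000, 512), ToyWorker(800, 640)])
assert executor.determine_num_available_blocks() == (800, 512)
```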
87 changes: 58 additions & 29 deletions vllm/v1/executor/abstract.py
@@ -1,63 +1,92 @@
-from abc import ABC, abstractmethod
from typing import Type

from vllm.config import VllmConfig
+from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.ray_distributed_executor import (  # noqa
+    RayDistributedExecutor as RayDistributedExecutorV0)
+from vllm.executor.uniproc_executor import (  # noqa
+    ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
+from vllm.executor.uniproc_executor import (  # noqa
+    UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput


-class Executor(ABC):
-    """Abstract class for executors."""
+class Executor(ExecutorBase):
+    """
+    Abstract class for v1 executors, mainly define some methods for v1.
+    For methods shared by v0 and v1, define them in ExecutorBase"""

    @staticmethod
    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
        executor_class: Type[Executor]
+        parallel_config = vllm_config.parallel_config
        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
+            parallel_config.distributed_executor_backend)
+        if distributed_executor_backend is None:
+            # If the user does not specify the distributed executor backend,
+            # we will choose the backend based on the world size.
+            if parallel_config.world_size > 1:
+                distributed_executor_backend = "mp"
+            else:
+                distributed_executor_backend = "uni"

        if distributed_executor_backend == "ray":
-            from vllm.executor.ray_distributed_executor import (  # noqa
-                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
+        elif distributed_executor_backend == "uni":
+            executor_class = UniProcExecutor
+        elif distributed_executor_backend == "external_launcher":
+            # TODO: make v1 scheduling deterministic
+            # to support external launcher
+            executor_class = ExecutorWithExternalLauncher
        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
+            raise ValueError("Unknown distributed executor backend: "
+                             f"{distributed_executor_backend}")
        return executor_class

-    @abstractmethod
-    def __init__(self, vllm_config: VllmConfig) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
-        raise NotImplementedError
+        """
+        Initialize the KV caches and begin the model execution loop of the
+        underlying workers.
+        """
+        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
+        self.collective_rpc("compile_or_warm_up_model")

-    @abstractmethod
    def determine_available_memory(self) -> int:  # in bytes
-        raise NotImplementedError
+        output = self.collective_rpc("determine_available_memory")
+        # Since we use a shared centralized controller, we take the minimum
+        # memory size across all workers to make sure all the memory
+        # operators can be applied to all workers.
+        return min(output)

-    @abstractmethod
    def get_kv_cache_spec(self) -> KVCacheSpec:
-        raise NotImplementedError
+        output = self.collective_rpc("get_kv_cache_spec")
+        for x in output:
+            assert x == output[0]
+        return output[0]

-    @abstractmethod
    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
-        raise NotImplementedError
+        output = self.collective_rpc("execute_model",
+                                     args=(scheduler_output, ))
+        return output[0]

-    @abstractmethod
    def profile(self, is_start: bool = True):
-        raise NotImplementedError
+        self.collective_rpc("profile", args=(is_start, ))


+class UniProcExecutor(UniProcExecutorV0, Executor):
+    pass
+
+
+class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
+    pass
+
-    @abstractmethod
-    def shutdown(self):
-        pass
-
-    @abstractmethod
-    def check_health(self) -> None:
-        raise NotImplementedError
+
+class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
+    pass
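With v1's `Executor` now inheriting from `ExecutorBase`, the collective-RPC plumbing is shared, backend selection is concentrated in `get_class`, and the per-backend classes above are thin mixins over their v0 counterparts. A hypothetical usage sketch follows; the `build_executor` helper is illustrative, not part of vLLM.

```python
from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor


def build_executor(vllm_config: VllmConfig) -> Executor:
    # get_class() reads parallel_config.distributed_executor_backend
    # ("ray", "mp", "uni", or "external_launcher"); when unset, it falls
    # back to "mp" for world_size > 1 and "uni" otherwise.
    executor_cls = Executor.get_class(vllm_config)
    # All backends share ExecutorBase's constructor signature.
    return executor_cls(vllm_config)
```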
50 changes: 3 additions & 47 deletions vllm/v1/executor/multiproc_executor.py
@@ -25,8 +25,6 @@
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
-from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -37,7 +35,7 @@

class MultiprocExecutor(Executor):

-    def __init__(self, vllm_config: VllmConfig) -> None:
+    def _init_executor(self) -> None:
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
@@ -55,9 +53,6 @@ def sigusr1_handler(signum, frame):

        signal.signal(signal.SIGUSR1, sigusr1_handler)

-        self.vllm_config = vllm_config
-        self.parallel_config = vllm_config.parallel_config
-
        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        assert self.world_size == tensor_parallel_size, (
@@ -82,7 +77,8 @@
        # Create workers
        self.workers: List[WorkerProcHandle] = []
        for rank in range(self.world_size):
-            worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
+            worker = WorkerProc.make_worker_process(self.vllm_config, rank,
+                                                    rank,
                                                    distributed_init_method,
                                                    scheduler_output_handle)
            self.workers.append(worker)
@@ -93,34 +89,6 @@
        for w in self.workers:
            w.worker_response_mq.wait_until_ready()

-    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
-        """
-        Initialize the KV caches and begin the model execution loop of the
-        underlying workers.
-        """
-        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
-        self.collective_rpc("compile_or_warm_up_model")
-
-    def determine_available_memory(self) -> int:
-        """
-        Determine the available memory (in bytes) for KV cache by invoking the
-        underlying worker.
-        """
-        memory_sizes = self.collective_rpc("determine_available_memory")
-
-        # Since we use a shared centralized controller, we take the minimum
-        # memory size across all workers to make sure all the memory
-        # operators can be applied to all workers.
-        return min(memory_sizes)
-
-    def get_kv_cache_spec(self) -> KVCacheSpec:
-        """
-        Get all kv cache needed by the model by invoking the underlying worker.
-        """
-        kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
-        assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
-        return kv_cache_specs[0]
-
    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
@@ -172,18 +140,6 @@ def collective_rpc(self,
            # Re-raise any other exceptions
            raise e

-    def execute_model(
-        self,
-        scheduler_output,
-    ) -> ModelRunnerOutput:
-        model_output = self.collective_rpc("execute_model",
-                                           args=(scheduler_output, ))[0]
-        return model_output
-
-    def profile(self, is_start: bool = True):
-        self.collective_rpc("profile", args=(is_start, ))
-        return
-
    def _ensure_worker_termination(self):
        """Ensure that all worker processes are terminated. Assumes workers have
        received termination requests. Waits for processing, then sends
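The `__init__` → `_init_executor` change in this file is a template-method refactor: the shared v0 `ExecutorBase.__init__` stores `vllm_config`, derives `parallel_config`, and then calls the `_init_executor` hook, which is why the duplicated assignments could be deleted here. A minimal sketch of the pattern, with toy class names standing in for the real ones:

```python
class ToyExecutorBase:
    """Toy stand-in for the v0 ExecutorBase, which owns __init__."""

    def __init__(self, vllm_config) -> None:
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        self._init_executor()  # subclass hook

    def _init_executor(self) -> None:
        raise NotImplementedError


class ToyMultiprocExecutor(ToyExecutorBase):
    def _init_executor(self) -> None:
        # Config attributes set by the base __init__ are already present,
        # so the subclass does no duplicated bookkeeping.
        self.world_size = self.parallel_config.world_size
```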
(diffs for the remaining 3 changed files not shown)
