Skip to content

Commit

Permalink
fix worker startup
Browse files Browse the repository at this point in the history
Signed-off-by: Max de Bayser <[email protected]>
  • Loading branch information
maxdebayser committed Jan 13, 2025
1 parent 1457272 commit 69ae94b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 38 deletions.
13 changes: 6 additions & 7 deletions vllm/executor/multiproc_spyre_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.executor.spyre_executor import SpyreExecutor, create_worker
from vllm.executor.spyre_executor import SpyreExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -123,12 +123,11 @@ def _init_executor(self) -> None:
worker = ProcessWorkerWrapper(
result_handler,
partial(
create_worker,
**self._get_create_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)))
self._create_worker,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
))
self.workers.append(worker)
if rank % tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
Expand Down
33 changes: 3 additions & 30 deletions vllm/executor/spyre_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
from typing import Any, Dict, List, Optional, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
Expand All @@ -8,7 +8,7 @@
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

Expand Down Expand Up @@ -60,41 +60,14 @@ def _get_worker_kwargs(
or (rank % self.parallel_config.tensor_parallel_size == 0),
)

def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
worker_module_name = "vllm.worker.spyre_worker"
worker_class_name = "SpyreWorker"
return (worker_module_name, worker_class_name, worker_class_fn)

def _get_create_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict:

worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)

(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()
worker_kwargs.update(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)
return worker_kwargs

def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):

wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)

assert self.distributed_init_method is not None

wrapper.init_worker(**self._get_create_worker_kwargs(
wrapper.init_worker(**self._get_worker_kwargs(
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method))
Expand Down
3 changes: 2 additions & 1 deletion vllm/platforms/spyre.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if envs.VLLM_USE_V1:
raise NotImplementedError
else:
parallel_config.worker_cls = "vllm.worker.worker.SpyreWorker"
parallel_config.worker_cls = \
"vllm.worker.spyre_worker.SpyreWorker"

0 comments on commit 69ae94b

Please sign in to comment.