fix cache config and embed task name
Signed-off-by: Max de Bayser <[email protected]>
maxdebayser committed Jan 14, 2025
1 parent 69ae94b commit deda681
Showing 2 changed files with 16 additions and 2 deletions.
vllm/platforms/spyre.py: 14 additions & 0 deletions
@@ -1,5 +1,7 @@
 from typing import TYPE_CHECKING, Optional
 
+from vllm.logger import init_logger
+
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 else:
@@ -8,6 +10,7 @@
 
 from .interface import Platform, PlatformEnum
 
+logger = init_logger(__name__)
 
 class SpyrePlatform(Platform):
     _enum = PlatformEnum.SPYRE
@@ -39,3 +42,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = \
                 "vllm.worker.spyre_worker.SpyreWorker"
+
+        cache_config = vllm_config.cache_config
+        if cache_config:
+            # spyre needs block_size = max_model_len
+            vllm_config.cache_config.block_size = \
+                vllm_config.model_config.max_model_len
+
+    @classmethod
+    def is_pin_memory_available(cls) -> bool:
+        logger.warning("Pin memory is not supported on Spyre.")
+        return False
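
For context on the cache-config override: vLLM's paged KV cache gives a sequence of N tokens ceil(N / block_size) blocks, so pinning block_size to max_model_len guarantees every sequence fits in a single block, which matches Spyre's need for a non-paged layout. A minimal, self-contained sketch of that arithmetic (the ModelConfig/CacheConfig stand-ins and the numbers are illustrative, not vLLM's real classes):

```python
# Illustrative stand-ins only; vLLM's actual ModelConfig/CacheConfig
# carry many more fields. The point is the block-count arithmetic.
import math
from dataclasses import dataclass

@dataclass
class ModelConfig:
    max_model_len: int

@dataclass
class CacheConfig:
    block_size: int

def blocks_per_sequence(seq_len: int, cache: CacheConfig) -> int:
    # Paged KV cache: a sequence of seq_len tokens spans
    # ceil(seq_len / block_size) cache blocks.
    return math.ceil(seq_len / cache.block_size)

model = ModelConfig(max_model_len=2048)
cache = CacheConfig(block_size=16)        # a typical GPU default
print(blocks_per_sequence(2048, cache))   # 128 blocks, paged layout

cache.block_size = model.max_model_len    # the override in this commit
print(blocks_per_sequence(2048, cache))   # 1 block per sequence
```

The new is_pin_memory_available() override likewise reports that pinned host memory is unavailable on Spyre, so callers that consult it allocate ordinary pageable host buffers instead of attempting pinned allocations.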
vllm/worker/spyre_worker.py: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ def __init__(
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
 
-        if self.model_config.task == "embedding":
+        if self.model_config.task == "embed":
             self.model_runner: SpyreModelRunner = SpyreEmbeddingModelRunner(
                 model_config, parallel_config, scheduler_config, device_config)
         else:
@@ -147,7 +147,7 @@ def load_model(self):
                 (s["prompt_length"], s["new_tokens"], s["batch_size"])
                 for s in self.scheduler_config.spyre_warmup_shapes
             ]):
-                if self.model_config.task != "embedding":
+                if self.model_config.task != "embed":
                     # TODO: remove if spyre supports
                     # lower number of output tokens
                     assert num_decode_tokens >= 3, (
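
The spyre_worker.py change tracks vLLM's rename of the pooling task from "embedding" to "embed". After the upstream rename, the old string comparison could no longer match, so an embedding model would silently take the decoder (text-generation) branch. A hypothetical reduction of the dispatch above (the helper function is illustrative; only the runner names come from the diff):

```python
# Hypothetical condensation of SpyreWorker.__init__'s runner dispatch;
# the string comparison is the part this commit fixes.
def select_runner(task: str) -> str:
    # Only the exact task name "embed" selects the embedding runner;
    # the legacy spelling "embedding" falls through to the decoder path.
    if task == "embed":
        return "SpyreEmbeddingModelRunner"
    return "SpyreModelRunner"

assert select_runner("embed") == "SpyreEmbeddingModelRunner"
assert select_runner("embedding") == "SpyreModelRunner"  # the pre-fix failure mode
```

The same rename is applied to the warm-up check in load_model(), which only enforces the minimum decode-token count for non-embedding tasks.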
