Commit 2c58b76
llm-swarm backend integration for slurm clusters (#142)
1 parent: 28c89c7
Showing 7 changed files with 151 additions and 2 deletions (the two new llm_swarm backend files are reproduced below).
optimum_benchmark/backends/llm_swarm/backend.py (new file, +94)

```python
import asyncio
import gc
from logging import getLogger
from typing import Any, Dict, List

import torch
from huggingface_hub import AsyncInferenceClient
from llm_swarm import LLMSwarm
from llm_swarm import LLMSwarmConfig as LLMSwarmCfg

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from .config import LLMSwarmConfig

# backend logger
LOGGER = getLogger("llm-swarm")


class LLMSwarmBackend(Backend[LLMSwarmConfig]):
    NAME: str = "llm-swarm"

    def __init__(self, config: LLMSwarmConfig) -> None:
        super().__init__(config)
        self.validate_task()

        LOGGER.info("\t+ Downloading pretrained model")
        self.download_pretrained_model()
        LOGGER.info("\t+ Preparing generation config")
        self.prepare_generation_config()
        LOGGER.info("\t+ Loading pretrained model")
        self.load_model_from_pretrained()

    def validate_task(self) -> None:
        if self.config.task not in TEXT_GENERATION_TASKS:
            raise NotImplementedError(f"LLM Swarm does not support task {self.config.task}")

    def load_model_from_pretrained(self) -> None:
        self.llm_swarm_config = LLMSwarmCfg(
            gpus=self.config.gpus,
            model=self.config.model,
            instances=self.config.instances,
            inference_engine=self.config.inference_engine,
            slurm_template_path=self.config.slurm_template_path,
            load_balancer_template_path=self.config.load_balancer_template_path,
            per_instance_max_parallel_requests=self.config.per_instance_max_parallel_requests,
            revision=self.config.hub_kwargs.get("revision", "main"),
            debug_endpoint=self.config.debug_endpoint,
        )
        # __enter__ is called manually so the swarm stays up for the backend's
        # lifetime; the matching __exit__ happens in clean()
        self.llm_swarm = LLMSwarm(self.llm_swarm_config).__enter__()
        self.client = AsyncInferenceClient(self.llm_swarm.endpoint)

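    # instantiating under the meta device pulls the checkpoint files into the
    # Hub cache without materializing the weights in host or device memory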
    def download_pretrained_model(self) -> None:
        with torch.device("meta"):
            self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)

    def prepare_generation_config(self) -> None:
        # -100 matches no real token, so generation never stops early on EOS
        self.generation_config.eos_token_id = -100
        self.generation_config.pad_token_id = -100
        # resolve the snapshot directory inside the Hub cache layout
        # (<volume>/models--<org>--<name>/snapshots/<ref>)
        model_cache_folder = f"models/{self.config.model}".replace("/", "--")
        model_cache_path = f"{self.config.volume}/{model_cache_folder}"
        snapshot_file = f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}"
        snapshot_ref = open(snapshot_file, "r").read().strip()
        model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
        LOGGER.info("\t+ Saving new pretrained generation config")
        self.generation_config.save_pretrained(save_directory=model_snapshot_path)

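    # the swarm is reached over HTTP, so tokenized benchmark inputs are decoded
    # back into prompt strings before being sent to the endpoint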
    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if "inputs" in inputs:
            return {"prompt": self.pretrained_processor.batch_decode(inputs["inputs"].tolist())}
        elif "input_ids" in inputs:
            return {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
        else:
            raise ValueError("inputs must contain either input_ids or inputs")

    async def single_client_call(self, prompt: str, kwargs: Dict[str, Any]) -> str:
        return await self.client.text_generation(prompt, max_new_tokens=kwargs.get("max_new_tokens", 1))

    async def batch_client_call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        # fan the batch out as concurrent requests against the load balancer
        return await asyncio.gather(*(self.single_client_call(p, kwargs) for p in inputs["prompt"]))

    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        return asyncio.run(self.batch_client_call(inputs, kwargs))

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        return asyncio.run(self.batch_client_call(inputs, kwargs))

    def clean(self) -> None:
        super().clean()

        if hasattr(self, "llm_swarm"):
            LOGGER.info("Cleaning up LLM Swarm")
            self.llm_swarm.__exit__(None, None, None)

        gc.collect()
```
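A minimal usage sketch (not part of the commit) of how this backend can be driven. It assumes the inherited `BackendConfig` exposes the `task`, `model`, and `hub_kwargs` fields the code above reads; the model name and all paths are placeholders for deployment-specific values:

```python
import torch

from optimum_benchmark.backends.llm_swarm.backend import LLMSwarmBackend
from optimum_benchmark.backends.llm_swarm.config import LLMSwarmConfig

config = LLMSwarmConfig(
    task="text-generation",
    model="HuggingFaceH4/zephyr-7b-beta",  # illustrative model
    gpus=8,
    instances=1,
    volume="/scratch/hf-cache",  # placeholder volume
    slurm_template_path="/scratch/templates/tgi_h100.template.slurm",  # placeholder
    load_balancer_template_path="/scratch/templates/nginx.template.conf",  # placeholder
)

# __init__ downloads the model, patches its generation config, and brings the swarm up
backend = LLMSwarmBackend(config)

# pre-tokenized prompt ids (illustrative); prepare_inputs decodes them back to text
inputs = backend.prepare_inputs({"input_ids": torch.tensor([[9906, 1917]])})
outputs = backend.generate(inputs, {"max_new_tokens": 32})

# tear the swarm down and release the Slurm allocation
backend.clean()
```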
optimum_benchmark/backends/llm_swarm/config.py (new file, +31)

```python
from dataclasses import dataclass
from typing import Optional

from ...import_utils import llm_swarm_version
from ..config import BackendConfig


@dataclass
class LLMSwarmConfig(BackendConfig):
    name: str = "llm-swarm"
    version: Optional[str] = llm_swarm_version()
    _target_: str = "optimum_benchmark.backends.llm_swarm.backend.LLMSwarmBackend"

    # optimum-benchmark specific
    no_weights: bool = False

    # llm-swarm specific
    gpus: int = 8
    instances: int = 1
    inference_engine: str = "tgi"
    volume: str = "/fsx/ilyas/.cache"
    per_instance_max_parallel_requests: int = 500
    slurm_template_path: str = "/fsx/ilyas/swarm-templates/tgi_h100.template.slurm"
    load_balancer_template_path: str = "/fsx/ilyas/swarm-templates/nginx.template.conf"
    debug_endpoint: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()

        # so that downloaded artifacts are stored in the same place
        self.hub_kwargs["cache_dir"] = self.volume
```
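The `volume` and template-path defaults above point at the author's cluster, so in practice they are overridden per deployment. A small sketch of the `__post_init__` side effect, assuming the inherited `BackendConfig` defaults are otherwise sufficient and using a placeholder path:

```python
from optimum_benchmark.backends.llm_swarm.config import LLMSwarmConfig

# hypothetical override; /scratch/hf-cache is a placeholder path
cfg = LLMSwarmConfig(volume="/scratch/hf-cache")

# __post_init__ routes all Hub downloads into the shared volume
assert cfg.hub_kwargs["cache_dir"] == "/scratch/hf-cache"
```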