diff --git a/src/instructlab/model/backends/backends.py b/src/instructlab/model/backends/backends.py
index d20eb73002..b96868851a 100644
--- a/src/instructlab/model/backends/backends.py
+++ b/src/instructlab/model/backends/backends.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Standard
-from time import monotonic, sleep, time
+from time import monotonic, sleep
 from types import FrameType
 from typing import Optional, Tuple
 import abc
@@ -26,9 +26,7 @@
 import uvicorn
 
 # Local
-from ...client import check_api_base
 from ...configuration import _serve as serve_config
-from ...configuration import get_api_base
 from ...utils import split_hostport
 from .common import CHAT_TEMPLATE_AUTO, LLAMA_CPP, VLLM
 
@@ -410,73 +408,6 @@ def wait_for_stable_vram_cuda(timeout: int) -> Tuple[bool, bool]:
             logger.debug("Could not free cuda cache: %s", e)
 
 
-def ensure_server(
-    backend: str,
-    api_base: str,
-    http_client=None,
-    host="localhost",
-    port=8000,
-    background=True,
-    foreground_allowed=False,
-    server_process_func=None,
-    max_startup_attempts=None,
-) -> Tuple[
-    Optional[multiprocessing.Process], Optional[subprocess.Popen], Optional[str]
-]:
-    """Checks if server is running, if not starts one as a subprocess. Returns the server process
-    and the URL where it's available."""
-
-    logger.info(f"Trying to connect to model server at {api_base}")
-    if check_api_base(api_base, http_client):
-        return (None, None, api_base)
-    port = free_tcp_ipv4_port(host)
-    logger.debug(f"Using available port {port} for temporary model serving.")
-
-    host_port = f"{host}:{port}"
-    temp_api_base = get_api_base(host_port)
-    vllm_server_process = None
-
-    if backend == VLLM:
-        # TODO: resolve how the hostname is getting passed around the class and this function
-        vllm_server_process = server_process_func(port, background)
-        logger.info("Starting a temporary vLLM server at %s", temp_api_base)
-        count = 0
-        # Each call to check_api_base takes >2s + 2s sleep
-        # Default to 120 if not specified (~8 mins of wait time)
-        vllm_startup_max_attempts = max_startup_attempts or 120
-        start_time_secs = time()
-        while count < vllm_startup_max_attempts:
-            count += 1
-            # Check if the process is still alive
-            if vllm_server_process.poll():
-                if foreground_allowed and background:
-                    raise ServerException(
-                        "vLLM failed to start. Retry with --enable-serving-output to learn more about the failure."
-                    )
-                raise ServerException("vLLM failed to start.")
-            logger.info(
-                "Waiting for the vLLM server to start at %s, this might take a moment... Attempt: %s/%s",
-                temp_api_base,
-                count,
-                vllm_startup_max_attempts,
-            )
-            if check_api_base(temp_api_base, http_client):
-                logger.info("vLLM engine successfully started at %s", temp_api_base)
-                break
-            if count == vllm_startup_max_attempts:
-                logger.info(
-                    "Gave up waiting for vLLM server to start at %s after %s attempts",
-                    temp_api_base,
-                    vllm_startup_max_attempts,
-                )
-                duration = round(time() - start_time_secs, 1)
-                shutdown_process(vllm_server_process, 20)
-                # pylint: disable=raise-missing-from
-                raise ServerException(f"vLLM failed to start up in {duration} seconds")
-            sleep(2)
-    return (None, vllm_server_process, temp_api_base)
-
-
 def free_tcp_ipv4_port(host: str) -> int:
     """Ask the OS for a random, ephemeral, and bindable TCP/IPv4 port
diff --git a/src/instructlab/model/backends/vllm.py b/src/instructlab/model/backends/vllm.py
index e0ccf40b7b..e207704dc0 100644
--- a/src/instructlab/model/backends/vllm.py
+++ b/src/instructlab/model/backends/vllm.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Standard
+from typing import Optional, Tuple
 import json
 import logging
 import os
@@ -15,12 +16,13 @@
 import httpx
 
 # Local
+from ...client import check_api_base
 from ...configuration import get_api_base
 from .backends import (
     BackendServer,
     Closeable,
     ServerException,
-    ensure_server,
+    free_tcp_ipv4_port,
     safe_close_all,
     shutdown_process,
 )
@@ -93,6 +95,61 @@ def create_server_process(self, port: int, background: bool) -> subprocess.Popen
         self.register_resources(files)
         return server_process
 
+    def _ensure_server(
+        self,
+        http_client=None,
+        background=True,
+        foreground_allowed=False,
+    ) -> Tuple[Optional[subprocess.Popen], Optional[str]]:
+        """Checks if server is running, if not starts one as a subprocess. Returns the server process
+        and the URL where it's available."""
+
+        logger.info(f"Trying to connect to model server at {self.api_base}")
+        if check_api_base(self.api_base, http_client):
+            return (None, self.api_base)
+        port = free_tcp_ipv4_port(self.host)
+        logger.debug(f"Using available port {port} for temporary model serving.")
+
+        host_port = f"{self.host}:{port}"
+        temp_api_base = get_api_base(host_port)
+        vllm_server_process = self.create_server_process(port, background)
+        logger.info("Starting a temporary vLLM server at %s", temp_api_base)
+        count = 0
+        # Each call to check_api_base takes >2s + 2s sleep
+        # Default to 120 if not specified (~8 mins of wait time)
+        vllm_startup_max_attempts = self.max_startup_attempts or 120
+        start_time_secs = time.time()
+        while count < vllm_startup_max_attempts:
+            count += 1
+            # Check if the process is still alive
+            if vllm_server_process.poll():
+                if foreground_allowed and background:
+                    raise ServerException(
+                        "vLLM failed to start. Retry with --enable-serving-output to learn more about the failure."
+                    )
+                raise ServerException("vLLM failed to start.")
+            logger.info(
+                "Waiting for the vLLM server to start at %s, this might take a moment... Attempt: %s/%s",
+                temp_api_base,
+                count,
+                vllm_startup_max_attempts,
+            )
+            if check_api_base(temp_api_base, http_client):
+                logger.info("vLLM engine successfully started at %s", temp_api_base)
+                break
+            if count == vllm_startup_max_attempts:
+                logger.info(
+                    "Gave up waiting for vLLM server to start at %s after %s attempts",
+                    temp_api_base,
+                    vllm_startup_max_attempts,
+                )
+                duration = round(time.time() - start_time_secs, 1)
+                shutdown_process(vllm_server_process, 20)
+                # pylint: disable=raise-missing-from
+                raise ServerException(f"vLLM failed to start up in {duration} seconds")
+            time.sleep(2)
+        return (vllm_server_process, temp_api_base)
+
     def run_detached(
         self,
         http_client: httpx.Client | None = None,
@@ -102,16 +159,10 @@
     ) -> str:
         for i in range(max_startup_retries + 1):
            try:
-                _, vllm_server_process, api_base = ensure_server(
-                    backend=VLLM,
-                    api_base=self.api_base,
+                vllm_server_process, api_base = self._ensure_server(
                     http_client=http_client,
-                    host=self.host,
-                    port=self.port,
                     background=background,
                     foreground_allowed=foreground_allowed,
-                    server_process_func=self.create_server_process,
-                    max_startup_attempts=self.max_startup_attempts,
                 )
                 self.process = vllm_server_process or self.process
                 self.api_base = api_base or self.api_base
diff --git a/tests/common.py b/tests/common.py
index 9d8fd677ae..755c1c9114 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -32,7 +32,7 @@ def setup_gpus_config(section_path="serve", gpus=None, tps=None, vllm_args=lambd
     return _CFG_FILE_NAME
 
 
-@mock.patch("instructlab.model.backends.backends.check_api_base", return_value=False)
+@mock.patch("instructlab.model.backends.vllm.check_api_base", return_value=False)
 # ^ mimic server *not* running already
 @mock.patch(
     "instructlab.model.backends.backends.determine_backend",