From 53409f21c24411f066f476756f5b99574655cc31 Mon Sep 17 00:00:00 2001 From: Sebastian Walter Date: Thu, 19 Dec 2024 19:26:35 +0100 Subject: [PATCH] update server --- pyproject.toml | 9 +- python/text_utils/api/cli.py | 116 ++++++-------- python/text_utils/api/server.py | 264 +++++++++++++++++--------------- python/text_utils/logging.py | 23 +-- 4 files changed, 205 insertions(+), 207 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d2f9501..eac3c02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,11 +26,12 @@ dependencies = [ "numpy>=1.26", "pyyaml>=6.0", "tqdm>=4.66", - "tensorboard>=2.16.0", - "flask>=3.0", + "tensorboard>=2.16", + "fastapi>=0.115", + "uvicorn>=0.34", "requests>=2.31", - "termcolor>=2.4.0", - "grammar-utils>=0.1.0" + "termcolor>=2.4", + "grammar-utils>=0.1" ] [project.scripts] diff --git a/python/text_utils/api/cli.py b/python/text_utils/api/cli.py index bfe8e74..8c20b5b 100644 --- a/python/text_utils/api/cli.py +++ b/python/text_utils/api/cli.py @@ -1,9 +1,11 @@ import argparse -from io import TextIOWrapper +import logging import sys import time import warnings -from typing import Iterator, Type, Any +from io import TextIOWrapper +from typing import Any, Iterator, Type + try: import readline # noqa except ImportError: @@ -16,7 +18,7 @@ from text_utils.api.server import TextProcessingServer from text_utils.api.table import generate_report, generate_table from text_utils.api.utils import ProgressIterator -from text_utils.logging import setup_logging, disable_logging +from text_utils.logging import disable_logging, setup_logging class TextProcessingCli: @@ -24,23 +26,16 @@ class TextProcessingCli: text_processing_server_cls: Type[TextProcessingServer] @classmethod - def parser( - cls, - name: str, - description: str - ) -> argparse.ArgumentParser: + def parser(cls, name: str, description: str) -> argparse.ArgumentParser: parser = argparse.ArgumentParser(name, description) model_group = parser.add_mutually_exclusive_group() default_model = cls.text_processor_cls.default_model() model_group.add_argument( "-m", "--model", - choices=[ - model.name for model in - cls.text_processor_cls.available_models() - ], + choices=[model.name for model in cls.text_processor_cls.available_models()], default=None if default_model is None else default_model.name, - help=f"Name of the model to use for {cls.text_processor_cls.task}" + help=f"Name of the model to use for {cls.text_processor_cls.task}", ) model_group.add_argument( "-e", @@ -48,55 +43,51 @@ def parser( type=str, default=None, help="Path to an experiment directory from which the model will be loaded " - "(use this when you trained your own model and want to use it)" + "(use this when you trained your own model and want to use it)", ) parser.add_argument( "--last", action="store_true", - help="Use last checkpoint instead of best, only works with experiments" + help="Use last checkpoint instead of best, only works with experiments", ) input_group = parser.add_mutually_exclusive_group() input_group.add_argument( - "-p", - "--process", - type=str, - default=None, - help="Text to process" + "-p", "--process", type=str, default=None, help="Text to process" ) input_group.add_argument( "-f", "--file", type=str, default=None, - help="Path to a text file which will be processed" + help="Path to a text file which will be processed", ) input_group.add_argument( "-i", "--interactive", action="store_true", default=None, - help="Start an interactive session where your command line input is processed" + help="Start an 
interactive session where your command line input is processed", ) parser.add_argument( "-o", "--out-path", type=str, default=None, - help="Path where processed text should be saved to" + help="Path where processed text should be saved to", ) parser.add_argument( "-d", "--device", type=str, nargs="+", - help="Specify one or more devices to use for inference, by default a single GPU is used if available" + help="Specify one or more devices to use for inference, by default a single GPU is used if available", ) parser.add_argument( "-n", "--num-threads", type=int, default=None, - help="Number of threads used for running the inference pipeline" + help="Number of threads used for running the inference pipeline", ) batch_limit_group = parser.add_mutually_exclusive_group() batch_limit_group.add_argument( @@ -105,7 +96,7 @@ def parser( type=int, default=16, help="Determines how many inputs will be processed at the same time, larger values should usually result " - "in faster processing but require more memory" + "in faster processing but require more memory", ) batch_limit_group.add_argument( "-t", @@ -113,7 +104,7 @@ def parser( type=int, default=None, help="Determines the maximum number of tokens processed at the same time, larger values should usually " - "result in faster processing but require more memory" + "result in faster processing but require more memory", ) parser.add_argument( "-u", @@ -121,65 +112,65 @@ def parser( action="store_true", help="Disable sorting of the inputs before processing (for a large number of inputs or large text files " "sorting the sequences beforehand leads to speed ups because it minimizes the amount of padding " - "needed within a batch of sequences)" + "needed within a batch of sequences)", ) parser.add_argument( "-l", "--list", action="store_true", - help="List all available models with short descriptions" + help="List all available models with short descriptions", ) parser.add_argument( "-v", "--version", action="store_true", - help=f"Print name and version of the underlying {cls.text_processor_cls.task} library" + help=f"Print name and version of the underlying {cls.text_processor_cls.task} library", ) parser.add_argument( "--force-download", action="store_true", - help="Download the model again even if it already was downloaded" + help="Download the model again even if it already was downloaded", ) parser.add_argument( "--download-dir", type=str, default=None, - help="Directory the model will be downloaded to (as zip file)" + help="Directory the model will be downloaded to (as zip file)", ) parser.add_argument( "--cache-dir", type=str, default=None, - help="Directory the downloaded model will be extracted to" + help="Directory the downloaded model will be extracted to", ) parser.add_argument( "--server", type=str, default=None, - help=f"Path to a yaml config file to run a {cls.text_processor_cls.task} server" + help=f"Path to a yaml config file to run a {cls.text_processor_cls.task} server", ) parser.add_argument( "--report", action="store_true", - help="Print a runtime report (ignoring startup time) at the end of the processing" + help="Print a runtime report (ignoring startup time) at the end of the processing", ) parser.add_argument( "--progress", action="store_true", - help="Show a progress bar while processing" + help="Show a progress bar while processing", ) parser.add_argument( "--log-level", type=str, - choices=["info", "debug", "warning", "error", "critical"], + choices=list(logging._nameToLevel), default=None, - help="Sets the logging level for 
the underlying loggers" + help="Sets the logging level for the underlying loggers", ) parser.add_argument( "--profile", type=str, default=None, - help="Run CLI with cProfile profiler on and output stats to this file" + help="Run CLI with cProfile profiler on and output stats to this file", ) return parser @@ -191,20 +182,16 @@ def version(self) -> str: def _run_with_profiling(self, file: str) -> None: import cProfile + cProfile.runctx("self.run()", globals(), locals(), file) def process_iter( - self, - processor: TextProcessor, - iter: Iterator[str] + self, processor: TextProcessor, iter: Iterator[str] ) -> Iterator[Any]: raise NotImplementedError def setup(self) -> TextProcessor: - device = self.args.device or ( - "cuda" if torch.cuda.is_available() - else "cpu" - ) + device = self.args.device or ("cuda" if torch.cuda.is_available() else "cpu") if self.args.experiment: cor = self.text_processor_cls.from_experiment( experiment_dir=self.args.experiment, @@ -217,7 +204,7 @@ def setup(self) -> TextProcessor: device=device, download_dir=self.args.download_dir, cache_dir=self.args.cache_dir, - force_download=self.args.force_download + force_download=self.args.force_download, ) return cor @@ -241,17 +228,20 @@ def run(self) -> None: table = generate_table( headers=[["Model", "Description", "Tags"]], data=[ - [model.name, model.description, ", ".join( - str(tag) for tag in model.tags)] + [ + model.name, + model.description, + ", ".join(str(tag) for tag in model.tags), + ] for model in self.text_processor_cls.available_models() ], alignments=["left", "left", "left"], - max_column_width=80 + max_column_width=80, ) print(table) return elif self.args.server is not None: - setup_logging((self.args.log_level or "INFO").upper()) + setup_logging(self.args.log_level or logging.INFO) self.text_processing_server_cls.from_config(self.args.server).run() return @@ -270,9 +260,7 @@ def run(self) -> None: start = time.perf_counter() if self.args.process is not None: self.args.progress = False - for output in self.process_iter( - self.cor, iter([self.args.process]) - ): + for output in self.process_iter(self.cor, iter([self.args.process])): print(output) elif self.args.file is not None: @@ -283,14 +271,8 @@ def run(self) -> None: assert isinstance(self.args.out_path, str) out = open(self.args.out_path, "w") - input_it = ( - line.rstrip("\r\n") - for line in open(self.args.file) - ) - sized_it = ProgressIterator( - input_it, - self.input_size - ) + input_it = (line.rstrip("\r\n") for line in open(self.args.file)) + sized_it = ProgressIterator(input_it, self.input_size) for output in self.process_iter(self.cor, sized_it): out.write(output + "\n") @@ -330,14 +312,8 @@ def run(self) -> None: try: # correct lines from stdin as they come - input_it = ( - line.rstrip("\r\n") - for line in sys.stdin - ) - sized_it = ProgressIterator( - input_it, - self.input_size - ) + input_it = (line.rstrip("\r\n") for line in sys.stdin) + sized_it = ProgressIterator(input_it, self.input_size) for output in self.process_iter(self.cor, sized_it): print(output, flush=self.args.unsorted) diff --git a/python/text_utils/api/server.py b/python/text_utils/api/server.py index efb8e1c..9b1a3b3 100644 --- a/python/text_utils/api/server.py +++ b/python/text_utils/api/server.py @@ -1,186 +1,204 @@ -import os -from contextlib import contextmanager +import asyncio import logging -from threading import Lock -from typing import Dict, Any, Type, Union, Generator +from contextlib import asynccontextmanager +from typing import Any, Type -import yaml 
import torch -from flask import Flask, Response, cli, jsonify +import uvicorn +import yaml +from fastapi import FastAPI, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send -from text_utils.api.processor import TextProcessor, ModelInfo -from text_utils.api.utils import gpu_info, cpu_info -from text_utils.logging import get_logger from text_utils import configuration +from text_utils.api.processor import ModelInfo, TextProcessor +from text_utils.api.utils import cpu_info, gpu_info +from text_utils.logging import get_logger + + +class RequestCancelledMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + print(f"In {scope['type']} scope") + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + # Let's make a shared queue for the request messages + queue = asyncio.Queue() + + async def message_poller(sentinel: object, handler_task: asyncio.Task): + nonlocal queue + while True: + message = await receive() + if message["type"] == "http.disconnect": + print("Canceling handler task") + handler_task.cancel() + return sentinel # Break the loop + + # Puts the message in the queue + await queue.put(message) + + sentinel = object() + handler_task = asyncio.create_task(self.app(scope, queue.get, send)) # type: ignore + asyncio.create_task(message_poller(sentinel, handler_task)) + + try: + return await handler_task + except asyncio.CancelledError: + print("Cancelling request due to disconnect") class Error: - def __init__(self, msg: str, status: int): - self.msg = msg - self.status = status + def __init__(self, error: str, status_code: int): + self.error = error + self.status_code = status_code - def to_response(self) -> Response: - return Response(self.msg, status=self.status) + def to_response(self) -> JSONResponse: + return JSONResponse({"error": self.error}, self.status_code) class TextProcessingServer: text_processor_cls: Type[TextProcessor] @classmethod - def from_config(cls, path: str) -> "TextProcessingServer": + def from_config( + cls, path: str, log_level: str | int | None = None + ) -> "TextProcessingServer": config = configuration.load_config(path) - return cls(config) + return cls(config, log_level) - def __init__(self, config: Dict[str, Any]): + def __init__(self, config: dict[str, Any], log_level: str | int | None = None): self.config = config - self.logger = get_logger( - f"{self.text_processor_cls.task.upper()} SERVER" - ) - self.logger.info(f"loaded server config:\n{yaml.dump(config)}") + self.logger = get_logger(f"{self.text_processor_cls.task} server", log_level) + self.logger.info(f"Loaded server config:\n{yaml.dump(config)}") self.port = int(self.config.get("port", 40000)) - # disable flask startup message and set flask mode to development - cli.show_server_banner = lambda *_: None - os.environ["FLASK_DEBUG"] = "development" - self.server = Flask(__name__) - max_content_length = int( - float(config.get("max_content_length", 1000.0)) * 1000.0 - ) - self.server.config["MAX_CONTENT_LENGTH"] = max_content_length - self.max_models_per_gpu = max(1, config.get("max_models_per_gpu", 3)) - self.allow_origin = config.get("allow_origin", "*") - self.timeout = float(config.get("timeout", 10.0)) - logging.getLogger("werkzeug").disabled = True - self.num_gpus = torch.cuda.device_count() - assert "models" in config and len(config["models"]) > 0, \ - "expected at least one model 
to be specified in the server config" + self.server = FastAPI() + self.server.add_middleware( + CORSMiddleware, + allow_origins=[config.get("allow_origin", "*")], + allow_credentials=True, + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["*"], + ) + self.server.add_middleware(RequestCancelledMiddleware) - @self.server.after_request - def _after_request(response: Response) -> Response: - response.headers.add( - "Access-Control-Allow-Origin", - self.allow_origin - ) - response.headers.add( - "Access-Control-Allow-Headers", - "*" - ) - response.headers.add( - "Access-Control-Allow-Private-Network", - "true" - ) - return response + self.max_models_per_gpu = max(1, config.get("max_models_per_gpu", 1)) + self.timeout = config.get("timeout", 10.0) + self.num_gpus = torch.cuda.device_count() - @self.server.route("/info") - def _info() -> Response: - response = jsonify({ - "gpu": [gpu_info(i) for i in range(self.num_gpus)], - "cpu": cpu_info(), - "timeout": self.timeout, - }) - return response - - self.text_processors: list[TextProcessor] = [] - self.name_to_idx = {} - self.lock = Lock() - - model_infos = [] - assert "models" in config, "expected models in server config" - for i, cfg in enumerate(config["models"]): - if "device" in cfg: - device = cfg["device"] + assert ( + "models" in config and len(config["models"]) > 0 + ), "Expected at least one model to be specified in the server config" + + self.text_processors: dict[str, TextProcessor] = {} + self.model_infos = {} + self.model_cfgs = {} + assert "models" in config, "Expected models in server config" + for name, model_cfg in config["models"].items(): + if "device" in model_cfg: + device = model_cfg["device"] elif self.num_gpus > 0: device = f"cuda:{len(self.text_processors) % self.num_gpus}" else: device = "cpu" - if "name" in cfg: - model_name = cfg["name"] + if "name" in model_cfg: + model_name = model_cfg["name"] model_info = next( filter( lambda m: m.name == model_name, - self.text_processor_cls.available_models() + self.text_processor_cls.available_models(), ), - None + None, ) if model_info is None: raise RuntimeError( - f"model {model_name} not found in available models" + f"Model {model_name} not found in available models" ) self.logger.info( - f"loading pretrained model {model_info.name} for task " + f"Loading pretrained model {model_info.name} for task " f"{self.text_processor_cls.task} onto device {device}" ) text_processor = self.text_processor_cls.from_pretrained( - model_name, - device + model_name, device ) model_info.tags.append("src::pretrained") - elif "path" in cfg: - path = cfg["path"] + elif "path" in model_cfg: + path = model_cfg["path"] self.logger.info( - f"loading model for task {self.text_processor_cls.task} " + f"Loading model for task {self.text_processor_cls.task} " f"from experiment {path} onto device {device}" ) - text_processor = self.text_processor_cls.from_experiment( - path, - device - ) + text_processor = self.text_processor_cls.from_experiment(path, device) model_info = ModelInfo( name=text_processor.name, - description="loaded from custom experiment", - tags=["src::experiment"] + description="Loaded from custom experiment", + tags=["src::experiment"], ) else: - raise RuntimeError( - "expected either name or path in model config" - ) + raise RuntimeError("Expected either name or path in model config") # handle the case when two models have the same name - if model_info.name in self.text_processors: - raise RuntimeError( - f"got multiple models with name '{model_info.name}', " - f"second one at 
position {i + 1}" - ) + if name in self.text_processors: + raise RuntimeError(f"Got multiple models with name '{name}'") - model_infos.append(model_info) - self.text_processors.append(text_processor) - self.name_to_idx[model_info.name] = i + self.model_infos[name] = model_info + self.model_cfgs[name] = model_cfg + self.text_processors[name] = text_processor - @self.server.route("/models") - def _models() -> Response: - response = jsonify({ - "task": self.text_processor_cls.task, - "models": [ - info._asdict() - for info in model_infos - ] - }) - return response - - @contextmanager - def text_processor(self, model_name: str) -> Generator[Union[TextProcessor, Error], None, None]: - if model_name not in self.name_to_idx: - yield Error(f"model {model_name} does not exist", 404) - return + self.lock = asyncio.Lock() - acquired = self.lock.acquire(timeout=self.timeout) - if not acquired: - yield Error(f"failed to reserve model within {self.timeout}s", 503) + @self.server.get("/info") + async def info() -> dict[str, Any]: + return { + "gpu": [gpu_info(i) for i in range(self.num_gpus)], + "cpu": cpu_info(), + "timeout": self.timeout, + } + + @self.server.get("/models") + async def models() -> dict[str, Any]: + return { + "task": self.text_processor_cls.task, + "models": { + name: info._asdict() for name, info in self.model_infos.items() + }, + } + + @asynccontextmanager + async def get_text_processor(self, name: str): + if name not in self.text_processors: + yield Error(f"Model {name} does not exist", status.HTTP_404_NOT_FOUND) return try: - yield self.text_processors[self.name_to_idx[model_name]] + await asyncio.wait_for(self.lock.acquire(), timeout=self.timeout) + + yield self.text_processors[name] + + except asyncio.TimeoutError: + yield Error( + f"Failed to acquire lock within {self.timeout:.2f}s", + status.HTTP_503_SERVICE_UNAVAILABLE, + ) + finally: - self.lock.release() + if self.lock.locked(): + self.lock.release() def run(self): - self.server.run( - "0.0.0.0", - self.port, - debug=False, - use_reloader=False + uvicorn.run( + self.server, + host="0.0.0.0", + port=self.port, + log_level=self.logger.level, + limit_concurrency=32, ) diff --git a/python/text_utils/logging.py b/python/text_utils/logging.py index b76075b..a8bd86a 100644 --- a/python/text_utils/logging.py +++ b/python/text_utils/logging.py @@ -2,11 +2,16 @@ LOG_FORMAT = "[%(asctime)s] {%(name)s - %(levelname)s} %(message)s" -__all__ = ["setup_logging", "add_file_log", "get_logger", - "eta_minutes_message", "eta_seconds_message"] +__all__ = [ + "setup_logging", + "add_file_log", + "get_logger", + "eta_minutes_message", + "eta_seconds_message", +] -def setup_logging(level: int | str = logging.INFO) -> None: +def setup_logging(level: str | int | None = None) -> None: """ Sets up logging with a custom log format and level. @@ -14,10 +19,7 @@ def setup_logging(level: int | str = logging.INFO) -> None: :param level: log level :return: None """ - logging.basicConfig( - format=LOG_FORMAT, - level=logging.getLevelName(level) - ) + logging.basicConfig(format=LOG_FORMAT, level=level) def disable_logging() -> None: @@ -44,7 +46,7 @@ def add_file_log(logger: logging.Logger, log_file: str) -> None: logger.addHandler(file_handler) -def get_logger(name: str, level: int | None = None) -> logging.Logger: +def get_logger(name: str, level: str | int | None = None) -> logging.Logger: """ Get a logger that writes to stderr. 
@@ -60,8 +62,9 @@ def get_logger(name: str, level: int | None = None) -> logging.Logger: stderr_handler.setFormatter(logging.Formatter(LOG_FORMAT)) if not logger.hasHandlers(): logger.addHandler(stderr_handler) - if level is not None: - logger.setLevel(level) + + logger.setLevel(level or logging.root.level) + return logger
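
With the move from Flask to FastAPI, the base server only registers the GET /info and /models routes itself; task-specific processing endpoints are expected to come from subclasses, which can build on the new async get_text_processor context manager to serialize access to the loaded models. A minimal sketch of such a subclass, assuming a hypothetical /process route, request body, MyTextProcessor class and proc.process method; none of these are defined by this patch:

import asyncio
from typing import Any

from pydantic import BaseModel

from text_utils.api.server import Error, TextProcessingServer


class ProcessRequest(BaseModel):
    model: str
    text: str


class MyTaskServer(TextProcessingServer):
    # a concrete TextProcessor subclass for the task; not part of this patch
    text_processor_cls = MyTextProcessor

    def __init__(self, config: dict[str, Any], log_level: str | int | None = None):
        super().__init__(config, log_level)

        @self.server.post("/process")
        async def process(req: ProcessRequest):
            async with self.get_text_processor(req.model) as proc:
                if isinstance(proc, Error):
                    # model unknown, or the lock was not acquired within the timeout
                    return proc.to_response()
                # run the blocking model call off the event loop; the actual
                # method name depends on the TextProcessor subclass
                output = await asyncio.to_thread(proc.process, req.text)
                return {"output": output}

Because get_text_processor holds the asyncio lock until the async with block exits, only one request at a time is handed a processor, and the timeout from the server config bounds how long other requests wait for it.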
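
On the config side, the models section is now read as a mapping from served name to a per-model config (the Flask version iterated over a list), and port, allow_origin, timeout and max_models_per_gpu are read from the top level with the defaults shown in __init__. A sketch of a matching config dict, passed straight to the constructor; the served names, model name, experiment path and devices are placeholders:

config = {
    "port": 40000,              # default when omitted
    "allow_origin": "*",        # forwarded to CORSMiddleware
    "timeout": 10.0,            # seconds to wait for the model lock
    "max_models_per_gpu": 1,
    "models": {
        # served name -> model config; give either "name" (a pretrained
        # model) or "path" (an experiment directory), "device" is optional
        "pretrained": {"name": "some-available-model", "device": "cuda:0"},
        "custom": {"path": "experiments/my_run", "device": "cpu"},
    },
}

MyTaskServer(config).run()

The same structure applies to the YAML file passed via --server; from_config loads it with configuration.load_config and forwards the resulting dict.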
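
The two routes the base class does register can be queried with requests (already a dependency); host and port below assume the defaults from the sketch above:

import requests

base = "http://localhost:40000"

info = requests.get(f"{base}/info").json()
print(info["timeout"], info["cpu"], info["gpu"])

models = requests.get(f"{base}/models").json()
# "models" is now a mapping from served name to the ModelInfo fields
for name, model in models["models"].items():
    print(name, model["description"], ", ".join(model["tags"]))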
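
Finally, RequestCancelledMiddleware cancels the handler task as soon as the client disconnects, so a handler that is still waiting for or holding the model lock runs its finally blocks instead of finishing work for a caller that is gone. A standalone sketch to observe the behaviour, using a made-up /slow route; nothing here beyond the middleware itself comes from this patch:

import asyncio

import uvicorn
from fastapi import FastAPI

from text_utils.api.server import RequestCancelledMiddleware

app = FastAPI()
app.add_middleware(RequestCancelledMiddleware)


@app.get("/slow")
async def slow():
    # cancelled mid-sleep if the client aborts the request,
    # e.g. by interrupting curl http://localhost:8000/slow
    await asyncio.sleep(30)
    return {"done": True}


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)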