Add TGIS CLI integrated with upstream CLI
Signed-off-by: Rafael Vasquez <[email protected]>
Co-authored-by: Prashant Gupta <[email protected]>
rafvasq and prashantgupta24 committed Jun 28, 2024
1 parent 095df75 commit 10d1d22
Showing 9 changed files with 690 additions and 40 deletions.
4 changes: 2 additions & 2 deletions benchmarks/benchmark_serving.py
@@ -2,8 +2,8 @@

 On the server side, run one of the following commands:
     vLLM OpenAI API server
-    python -m vllm.entrypoints.openai.api_server \
-        --model <your_model> --swap-space 16 \
+    vllm serve <your_model> \
+        --swap-space 16 \
         --disable-log-requests

     (TGI backend)
2 changes: 1 addition & 1 deletion docs/source/serving/openai_compatible_server.md
@@ -109,7 +109,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)

 ```{argparse}
 :module: vllm.entrypoints.openai.cli_args
-:func: make_arg_parser
+:func: create_parser_for_docs
 :prog: -m vllm.entrypoints.openai.api_server
 ```

5 changes: 5 additions & 0 deletions setup.py
@@ -450,4 +450,9 @@ def _read_requirements(filename: str) -> List[str]:
     },
     cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
     package_data=package_data,
+    entry_points={
+        "console_scripts": [
+            "vllm=vllm.scripts:main",
+        ],
+    },
 )
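The new `console_scripts` hook is what makes the `vllm serve` invocations elsewhere in this diff work: installing the package now creates a `vllm` executable that dispatches to `vllm.scripts:main`. `vllm/scripts.py` is one of the nine changed files but is not rendered in this excerpt, so the following is only a sketch of how such an entry point is typically wired to `run_server`; everything except `run_server`, `make_arg_parser`, `add_tgis_args`, and `postprocess_tgis_args` (which appear in the diffs below) is hypothetical.

```python
# Hypothetical sketch of a "vllm" console-script entry point; the real
# vllm/scripts.py is part of this commit but is not shown in this excerpt.
import argparse

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args


def serve(args: argparse.Namespace) -> None:
    # Under `vllm serve`, the positional model tag stands in for --model.
    args.model = args.model_tag
    run_server(postprocess_tgis_args(args))


def main() -> None:
    parser = argparse.ArgumentParser(description="vLLM CLI")
    subparsers = parser.add_subparsers(dest="subcommand", required=True)

    serve_parser = subparsers.add_parser(
        "serve", help="Start the OpenAI-compatible vLLM API server.")
    serve_parser.add_argument("model_tag", type=str, help="Model to serve.")
    serve_parser = make_arg_parser(serve_parser)  # upstream OpenAI flags
    serve_parser = add_tgis_args(serve_parser)    # TGIS-specific flags
    serve_parser.set_defaults(dispatch_function=serve)

    args = parser.parse_args()
    args.dispatch_function(args)


if __name__ == "__main__":
    main()
```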
Empty file added tests/tgis/test_hub.py
9 changes: 4 additions & 5 deletions tests/utils.py
@@ -3,6 +3,7 @@
 import sys
 import time
 import warnings
+from argparse import ArgumentParser
 from contextlib import contextmanager
 from typing import List

@@ -31,10 +32,7 @@ def __init__(self, cli_args: List[str], *, wait_url: str,
         env = os.environ.copy()
         env["PYTHONUNBUFFERED"] = "1"
         self.proc = subprocess.Popen(
-            [
-                sys.executable, "-m", "vllm.entrypoints.openai.api_server",
-                *cli_args
-            ],
+            ["vllm", "serve", *cli_args],
             env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
@@ -74,7 +72,8 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:

         cli_args = cli_args + ["--port", str(get_open_port())]

-        parser = make_arg_parser()
+        parser = ArgumentParser(description="vLLM's remote OpenAI server.")
+        parser = make_arg_parser(parser)
         args = parser.parse_args(cli_args)
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
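With this change the test helper shells out to the installed `vllm` console script instead of `python -m vllm.entrypoints.openai.api_server`, so the package (and its entry point) must be installed in the test environment. Below is a standalone sketch of the same launch-and-wait pattern; the model name, port, and timeout are illustrative, not values taken from the test suite.

```python
# Minimal sketch: start the server through the new console script and
# poll /health until it is ready; all concrete values are illustrative.
import subprocess
import sys
import time

import requests

MODEL = "facebook/opt-125m"  # any small model works for a smoke test
PORT = 8000

proc = subprocess.Popen(
    ["vllm", "serve", MODEL, "--port", str(PORT)],
    stdout=sys.stdout,
    stderr=sys.stderr,
)
try:
    deadline = time.time() + 240  # model download and startup can be slow
    while time.time() < deadline:
        try:
            if requests.get(f"http://localhost:{PORT}/health").status_code == 200:
                break
        except requests.ConnectionError:
            pass
        time.sleep(1)
    else:
        raise RuntimeError("server did not become healthy in time")
    # ... issue OpenAI-compatible requests against http://localhost:{PORT}/v1 ...
finally:
    proc.terminate()
```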
81 changes: 53 additions & 28 deletions vllm/entrypoints/openai/api_server.py
@@ -1,3 +1,4 @@
+import argparse
 import asyncio
 import importlib
 import inspect
@@ -8,7 +9,7 @@

 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -34,6 +35,9 @@

 TIMEOUT_KEEP_ALIVE = 5  # seconds

+logger = init_logger(__name__)
+engine: AsyncLLMEngine
+engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
@@ -67,50 +71,35 @@ async def _force_log():
     logger.info("gRPC server stopped")


-app = fastapi.FastAPI(lifespan=lifespan)
-
-
-def parse_args():
-    parser = make_arg_parser()
-    parser = add_tgis_args(parser)
-    parsed_args = parser.parse_args()
-    parsed_args = postprocess_tgis_args(parsed_args)
-    return parsed_args
-
+router = APIRouter()

 # Add prometheus asgi middleware to route /metrics requests
 route = Mount("/metrics", make_asgi_app())
 # Workaround for 307 Redirect for /metrics
 route.path_regex = re.compile('^/metrics(?P<path>.*)$')
-app.routes.append(route)
-
+router.routes.append(route)

-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
-

-@app.get("/health")
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
     return Response(status_code=200)


-@app.get("/v1/models")
+@router.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_chat.show_available_models()
     return JSONResponse(content=models.model_dump())


-@app.get("/version")
+@router.get("/version")
 async def show_version():
     ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)


-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -126,7 +115,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return JSONResponse(content=generator.model_dump())


-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -140,7 +129,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())


-@app.post("/v1/embeddings")
+@router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -151,8 +140,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())


-if __name__ == "__main__":
-    args = parse_args()
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path

     app.add_middleware(
         CORSMiddleware,
@@ -162,6 +153,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
         allow_headers=args.allowed_headers,
     )

+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
     if token := envs.VLLM_API_KEY or args.api_key:

         @app.middleware("http")
@@ -187,6 +184,12 @@ async def authentication(request: Request, call_next):
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")

+    return app
+
+
+def run_server(args, llm_engine=None):
+    app = build_app(args)
+
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

@@ -195,6 +198,8 @@ async def authentication(request: Request, call_next):
     else:
         served_model_names = [args.model]

+    global engine, engine_args
+
     engine_args = AsyncEngineArgs.from_cli_args(args)

     # Enforce pixel values as image input type for vision language models
@@ -206,8 +211,9 @@
             "Only --image-input-type 'pixel_values' is supported for serving "
             "vision language models with the vLLM API server.")

-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

     event_loop: Optional[asyncio.AbstractEventLoop]
     try:
@@ -223,6 +229,11 @@
         # When using single vLLM without engine_use_ray
         model_config = asyncio.run(engine.get_model_config())

+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+    global async_llm_engine
+
     openai_serving_chat = OpenAIServingChat(engine, model_config,
                                             served_model_names,
                                             args.response_role,
@@ -247,3 +258,17 @@
                 ssl_certfile=args.ssl_certfile,
                 ssl_ca_certs=args.ssl_ca_certs,
                 ssl_cert_reqs=args.ssl_cert_reqs)
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+
+    parser = argparse.ArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    parser = add_tgis_args(parser)
+    args = parser.parse_args()
+    args = postprocess_tgis_args(args)
+
+    run_server(args)
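Splitting the old `__main__` block into `build_app(args)` and `run_server(args, llm_engine=None)` gives the new CLI (and the TGIS gRPC layer) a programmatic entry point, and the optional `llm_engine` argument lets a caller pass in an engine it has already built instead of letting the server construct one. A sketch of that usage, assuming only the imports visible in this diff plus vLLM's public `AsyncLLMEngine` and `UsageContext` locations; the argument values are illustrative:

```python
# Sketch: start the OpenAI-compatible server from Python with a
# pre-built engine; the model and flags here are only examples.
import argparse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args
from vllm.usage.usage_lib import UsageContext

parser = argparse.ArgumentParser(
    description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
parser = add_tgis_args(parser)
args = postprocess_tgis_args(
    parser.parse_args(["--model", "facebook/opt-125m"]))

# Build the engine up front; run_server would otherwise do this itself.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs.from_cli_args(args),
    usage_context=UsageContext.OPENAI_API_SERVER)

run_server(args, llm_engine=engine)  # blocks inside uvicorn.run
```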
12 changes: 8 additions & 4 deletions vllm/entrypoints/openai/cli_args.py
@@ -10,7 +10,6 @@

 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.openai.serving_engine import LoRAModulePath
-from vllm.tgis_utils.args import EnvVarArgumentParser


 class LoRAParserAction(argparse.Action):
@@ -23,9 +22,8 @@ def __call__(self, parser, namespace, values, option_string=None):
         setattr(namespace, self.dest, lora_list)


-def make_arg_parser():
-    parser = EnvVarArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server.")
+def make_arg_parser(
+        parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
     parser.add_argument("--host",
                         type=nullable_str,
                         default=None,
@@ -114,3 +112,9 @@

     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser
+
+
+def create_parser_for_docs() -> argparse.ArgumentParser:
+    parser_for_docs = argparse.ArgumentParser(
+        prog="-m vllm.entrypoints.openai.api_server")
+    return make_arg_parser(parser_for_docs)
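Because `make_arg_parser` no longer builds its own `EnvVarArgumentParser`, the caller now owns the parser and simply asks this module to populate it, which is what `create_parser_for_docs`, the updated tests, and the server's `__main__` block all do. A minimal sketch of that composition (the flag values are illustrative; a subclass such as the TGIS `EnvVarArgumentParser`, whose import was dropped above, can be substituted for the plain `ArgumentParser`):

```python
# Sketch: the upstream flags attach to whatever parser the caller supplies.
import argparse

from vllm.entrypoints.openai.cli_args import make_arg_parser

parser = argparse.ArgumentParser(prog="vllm serve")
parser = make_arg_parser(parser)

# Illustrative flags; anything defined by make_arg_parser or
# AsyncEngineArgs.add_cli_args is accepted.
args = parser.parse_args(["--model", "facebook/opt-125m", "--port", "8080"])
print(args.model, args.port)
```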