Merge pull request #292 from MeetKai/vllm-lora-serving
vLLM LoRA serving
jeffreymeetkai authored Nov 15, 2024
2 parents 2afd2bb + ff3ffe6 commit bdd6eb0
Showing 8 changed files with 229 additions and 76 deletions.
37 changes: 37 additions & 0 deletions README.md
@@ -75,6 +75,43 @@ python server_vllm.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 -
python server_sglang.py --model-path "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --context-length 8192 --tp 2
```

#### LoRA Support (Currently Only in vLLM)

Similar to [LoRA in vLLM](https://docs.vllm.ai/en/latest/models/lora.html), our server supports serving LoRA adapters both at startup and dynamically.

To serve a LoRA adapter at startup, run the server with the `--lora-modules` argument:

```shell
python server_vllm.py --model {BASE_MODEL} --enable-lora --lora-modules {name}={path} {name}={path} --host 0.0.0.0 --port 8000
```
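
Internally (see `server_vllm.py` in this diff), each `{name}={path}` pair is split on `=` and registered as a vLLM `LoRARequest`. A simplified sketch of that mapping, using a hypothetical adapter name and path:

```python
from vllm.lora.request import LoRARequest
from vllm.utils import AtomicCounter

lora_id_counter = AtomicCounter(0)
served_loras = []

# Hypothetical values standing in for the {name}={path} pairs passed via --lora-modules
lora_modules = ["my_lora=/path/to/my_lora_adapter"]

for lora_module in lora_modules:
    lora_name, lora_path = lora_module.split("=")
    served_loras.append(
        LoRARequest(
            lora_name=lora_name,
            lora_int_id=lora_id_counter.inc(1),  # unique integer id per adapter
            lora_path=lora_path,
        )
    )
```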

To load and unload LoRA adapters dynamically, use the `/v1/load_lora_adapter` and `/v1/unload_lora_adapter` endpoints:
```shell
python server_vllm.py --model {BASE_MODEL} --enable-lora --host 0.0.0.0 --port 8000
# Load a LoRA adapter dynamically
curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "my_lora",
"lora_path": "/path/to/my_lora_adapter"
}'
# Example chat request routed to the LoRA adapter
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "my_lora",
"messages": [...],
"tools": [...],
"tool_choice": "auto"
}'
# Unload a LoRA adapter dynamically
curl -X POST http://localhost:8000/v1/unload_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "my_lora"
}'
```
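
The same load → chat → unload flow can be driven from Python. A minimal client sketch using `requests`, assuming the server above is running locally on port 8000, with a placeholder adapter path and an example tool definition:

```python
import requests

BASE_URL = "http://localhost:8000"  # assumes server_vllm.py is running locally with --enable-lora

# Load a LoRA adapter dynamically (the adapter path is a placeholder)
resp = requests.post(
    f"{BASE_URL}/v1/load_lora_adapter",
    json={"lora_name": "my_lora", "lora_path": "/path/to/my_lora_adapter"},
)
print(resp.status_code, resp.text)

# Route a chat request to the adapter by using its name as the model
resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "my_lora",
        "messages": [{"role": "user", "content": "What is the weather in Istanbul?"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                },
            }
        ],
        "tool_choice": "auto",
    },
)
print(resp.json())

# Unload the adapter when it is no longer needed
resp = requests.post(
    f"{BASE_URL}/v1/unload_lora_adapter", json={"lora_name": "my_lora"}
)
print(resp.status_code, resp.text)
```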


### Grammar Sampling (Only in vLLM)

12 changes: 9 additions & 3 deletions functionary/inference_utils.py
@@ -6,7 +6,7 @@
from pydantic import BaseModel
from transformers import StoppingCriteria, StoppingCriteriaList

from functionary.openai_types import Function
from functionary.openai_types import ChatCompletionRequest, Function
from functionary.prompt_template.prompt_utils import enforce_tool_choice


@@ -65,13 +65,19 @@ def create_error_response(
)


async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
if request.model not in served_model:
async def check_all_errors(
request: ChatCompletionRequest, served_model: List, served_loras: List = []
) -> Optional[JSONResponse]:

if request.model not in served_model and request.model not in [
lora.lora_name for lora in served_loras
]:
return create_error_response(
status_code=HTTPStatus.NOT_FOUND,
message=f"The model `{request.model}` does not exist.",
param=None,
)

if request.tools and request.functions:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
102 changes: 100 additions & 2 deletions functionary/vllm_inference.py
@@ -5,10 +5,16 @@

from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.protocol import (
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest,
)
from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from vllm.utils import AtomicCounter, random_uuid

from functionary.inference_stream import generate_openai_format_from_stream_async
from functionary.inference_utils import (
@@ -83,19 +89,109 @@ async def check_length(request, input_ids, model_config):
return None


async def process_load_lora_adapter(
request: LoadLoraAdapterRequest,
served_loras: List[LoRARequest],
lora_id_counter: AtomicCounter,
) -> Tuple[Union[str, JSONResponse], List[LoRARequest]]:

# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return (
create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message="Both 'lora_name' and 'lora_path' must be provided.",
param=None,
),
served_loras,
)
# Check if the lora adapter with the given name already exists
if any(
lora_request.lora_name == request.lora_name for lora_request in served_loras
):
return (
create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"The lora adapter '{request.lora_name}' has already been loaded.",
param=None,
),
served_loras,
)

lora_name, lora_path = request.lora_name, request.lora_path
unique_id = lora_id_counter.inc(1)
served_loras.append(
LoRARequest(lora_name=lora_name, lora_int_id=unique_id, lora_path=lora_path)
)

return f"Success: LoRA adapter '{lora_name}' added successfully.", served_loras


async def process_unload_lora_adapter(
request: UnloadLoraAdapterRequest, served_loras: List[LoRARequest]
) -> Tuple[Union[str, JSONResponse], List[LoRARequest]]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return (
create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message="either 'lora_name' and 'lora_int_id' needs to be provided.",
param=None,
),
served_loras,
)

# Check if the lora adapter with the given name exists
if not any(
lora_request.lora_name == request.lora_name for lora_request in served_loras
):
return (
create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"The lora adapter '{request.lora_name}' cannot be found.",
param=None,
),
served_loras,
)

lora_name = request.lora_name
served_loras = [
lora_request
for lora_request in served_loras
if lora_request.lora_name != lora_name
]

return f"Success: LoRA adapter '{lora_name}' removed successfully.", served_loras


def get_lora_adapter(
request: ChatCompletionRequest, served_loras: List[LoRARequest]
) -> Optional[LoRARequest]:
for lora in served_loras:
if request.model == lora.lora_name:
return lora
return None


async def process_chat_completion(
request: ChatCompletionRequest,
raw_request: Optional[Request],
tokenizer: Any,
served_model: List[str],
served_loras: List[LoRARequest],
engine_model_config: Any,
enable_grammar_sampling: bool,
engine: Any,
):
error_check_ret = await check_all_errors(request, served_model)
error_check_ret = await check_all_errors(request, served_model, served_loras)
if error_check_ret is not None:
return error_check_ret

# Get the lora adapter if it exists and replace tokenizer
lora_request = get_lora_adapter(request, served_loras)
if lora_request is not None:
tokenizer = get_lora_tokenizer(lora_request)

tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request)

prompt_token_ids = prepare_messages_for_inference(
@@ -150,6 +246,7 @@ async def process_chat_completion(
if enable_grammar_sampling:
result_generator = engine.generate(
prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
lora_request=lora_request,
sampling_params=sampling_params,
request_id=request_id,
tools_or_functions=tools_or_functions,
@@ -159,6 +256,7 @@
else:
result_generator = engine.generate(
prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
lora_request=lora_request,
sampling_params=sampling_params,
request_id=request_id,
)
2 changes: 1 addition & 1 deletion functionary/vllm_monkey_patch/async_llm_engine.py
@@ -56,8 +56,8 @@
from functionary.inference import (
get_lm_format_enforcer_vllm_logits_processor_from_tool_name,
)
from functionary.prompt_template.prompt_utils import resolve_json_refs
from functionary.openai_types import Tool
from functionary.prompt_template.prompt_utils import resolve_json_refs

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -3,6 +3,11 @@ name = "functionary"
version = "0.0.1"
description = "Chat language model that can use tools and interpret the results"
requires-python = ">=3.9"
dependencies = [
"jsonref~=1.1.0",
"json_source_map==1.0.5",
"PyYAML~=6.0.1",
]

[build-system]
requires = ["setuptools>=61.0"]
@@ -14,12 +19,8 @@ packages = ["functionary"]
[project.optional-dependencies]
vllm = [
"vllm==0.6.3.post1; sys_platform != 'darwin'",
"jsonref~=1.1.0",
"json_source_map==1.0.5",
"PyYAML~=6.0.1",
]
sglang = [
"jsonref~=1.1.0",
"python-multipart==0.0.12",
"orjson==3.10.10",
"sglang[all]==0.3.4.post1",
4 changes: 0 additions & 4 deletions requirements.txt

This file was deleted.

78 changes: 76 additions & 2 deletions server_vllm.py
@@ -30,12 +30,24 @@
from fastapi.responses import Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import mount_metrics
from vllm.entrypoints.openai.protocol import ModelCard, ModelList, ModelPermission
from vllm.entrypoints.openai.protocol import (
LoadLoraAdapterRequest,
ModelCard,
ModelList,
ModelPermission,
UnloadLoraAdapterRequest,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import AtomicCounter

from functionary.openai_types import ChatCompletionRequest
from functionary.vllm_inference import process_chat_completion
from functionary.vllm_inference import (
process_chat_completion,
process_load_lora_adapter,
process_unload_lora_adapter,
)

TIMEOUT_KEEP_ALIVE = 5 # seconds

@@ -46,6 +58,8 @@


served_model = []
served_loras = []
lora_id_counter = AtomicCounter(0)
app = fastapi.FastAPI()


@@ -72,6 +86,22 @@ async def show_available_models():
id=served_model, root=served_model, permission=[ModelPermission()]
)
)

for lora in served_loras:
parent = (
lora.base_model_name
if lora.base_model_name
else (served_model[0] if isinstance(served_model, list) else served_model)
)
model_cards.append(
ModelCard(
id=lora.lora_name,
root=lora.lora_path,
parent=parent,
permission=[ModelPermission()],
)
)

return ModelList(data=model_cards)


@@ -93,12 +123,39 @@ async def create_chat_completion(raw_request: Request):
raw_request=raw_request,
tokenizer=tokenizer,
served_model=served_model,
served_loras=served_loras,
engine_model_config=engine_model_config,
enable_grammar_sampling=args.grammar_sampling,
engine=engine,
)


@app.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest):
global served_loras

error, served_loras = await process_load_lora_adapter(
request, served_loras, lora_id_counter
)
if not isinstance(error, str):
return error

# `error` is the success message if it is a string
return Response(status_code=200, content=error)


@app.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
global served_loras

error, served_loras = await process_unload_lora_adapter(request, served_loras)
if not isinstance(error, str):
return error

# `error` is the success message if it is a string
return Response(status_code=200, content=error)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server."
@@ -117,6 +174,13 @@ async def create_chat_completion(raw_request: Request):
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument(
"--lora-modules",
nargs="*",
type=str,
help="LoRA modules in the format 'name=path name=path ...'",
default=[],
)
parser.add_argument(
"--enable-grammar-sampling",
dest="grammar_sampling",
@@ -156,6 +220,16 @@ async def create_chat_completion(raw_request: Request):
if args.served_model_name is not None:
served_model += args.served_model_name

for lora_module in args.lora_modules:
lora_name, lora_path = lora_module.split("=")
served_loras.append(
LoRARequest(
lora_name=lora_name,
lora_int_id=lora_id_counter.inc(1),
lora_path=lora_path,
)
)

engine_args = AsyncEngineArgs.from_cli_args(args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(