compat with old vllm version

xusenlin committed Jun 7, 2024
1 parent 87097bf commit 016bdff
Showing 3 changed files with 21 additions and 13 deletions.
8 changes: 7 additions & 1 deletion api/models.py
@@ -92,12 +92,15 @@ def create_hf_llm():
 def create_vllm_engine():
     """ get vllm generate engine for chat or completion. """
     try:
+        import vllm
         from vllm.engine.arg_utils import AsyncEngineArgs
         from vllm.engine.async_llm_engine import AsyncLLMEngine
         from api.core.vllm_engine import VllmEngine, LoRA
     except ImportError:
         raise ValueError("VLLM engine not available")
 
+    vllm_version = vllm.__version__
+
     include = {
         "tokenizer_mode",
         "trust_remote_code",
@@ -106,11 +109,14 @@ def create_vllm_engine():
         "gpu_memory_utilization",
         "max_num_seqs",
         "enforce_eager",
-        "max_seq_len_to_capture",
         "max_loras",
         "max_lora_rank",
         "lora_extra_vocab_size",
     }
+
+    if vllm_version >= "0.4.3":
+        include.add("max_seq_len_to_capture")
+
     kwargs = dictify(SETTINGS, include=include)
     engine_args = AsyncEngineArgs(
         model=SETTINGS.model_path,
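
Note: the new gate compares version strings, and string comparison is
lexicographic, so a release like "0.10.0" would sort below "0.4.3" ('1' < '4').
A minimal sketch of a numeric check, assuming the packaging library (a pip
dependency) is acceptable:

    # Sketch only, not part of this commit: compare parsed versions, not strings.
    import vllm
    from packaging.version import parse as parse_version

    def vllm_at_least(required: str = "0.4.3") -> bool:
        """True when the installed vLLM release is at least `required`."""
        return parse_version(vllm.__version__) >= parse_version(required)

    # Lexicographic comparison misorders multi-digit components:
    assert not ("0.10.0" >= "0.4.3")                          # string compare says "older"
    assert parse_version("0.10.0") >= parse_version("0.4.3")  # parsed compare is correct
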
13 changes: 7 additions & 6 deletions api/vllm_routes/chat.py
@@ -5,6 +5,7 @@
 from typing import AsyncIterator
 
 import anyio
+import vllm
 from fastapi import APIRouter, Depends, status
 from fastapi import HTTPException, Request
 from loguru import logger
@@ -38,6 +39,7 @@
 )
 
 chat_router = APIRouter(prefix="/chat")
+vllm_version = vllm.__version__
 
 
 def get_engine():
@@ -105,17 +107,16 @@ async def create_chat_completion(
     try:
         from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 
-        decoding_config = await engine.model.get_decoding_config()
-
-        try:
+        if vllm_version >= "0.4.3":
+            decoding_config = await engine.model.get_decoding_config()
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request.guided_decoding_backend or decoding_config.guided_decoding_backend,
                     request,
                     engine.tokenizer,
                 )
             )
-        except TypeError:
+        else:
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request,
@@ -128,7 +129,7 @@ async def create_chat_completion(
     except ImportError:
         pass
 
-    try:
+    if vllm_version >= "0.4.3":
         result_generator = engine.model.generate(
             {
                 "prompt": prompt if isinstance(prompt, str) else None,
@@ -138,7 +139,7 @@ async def create_chat_completion(
             request_id,
             lora_request,
         )
-    except TypeError:
+    else:
         result_generator = engine.model.generate(
             prompt if isinstance(prompt, str) else None,
             sampling_params,
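
The guided-decoding branch above is duplicated verbatim in completion.py below.
A hypothetical consolidation (not part of this commit) that keeps the version
dispatch in one helper; engine and request are the same objects the route
handlers already hold:

    # Hypothetical helper, not in the commit: one home for the version branch.
    from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor

    async def build_guided_processor(engine, request, vllm_version: str):
        if vllm_version >= "0.4.3":
            # vLLM >= 0.4.3 resolves the backend from the engine's decoding config.
            decoding_config = await engine.model.get_decoding_config()
            backend = request.guided_decoding_backend or decoding_config.guided_decoding_backend
            return await get_guided_decoding_logits_processor(backend, request, engine.tokenizer)
        # Older releases take only the request and the tokenizer.
        return await get_guided_decoding_logits_processor(request, engine.tokenizer)
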
13 changes: 7 additions & 6 deletions api/vllm_routes/completion.py
@@ -6,6 +6,7 @@
 from typing import AsyncIterator, Tuple
 
 import anyio
+import vllm
 from fastapi import APIRouter, Depends
 from fastapi import Request
 from loguru import logger
@@ -27,6 +28,7 @@
 )
 
 completion_router = APIRouter()
+vllm_version = vllm.__version__
 
 
 def get_engine():
@@ -144,17 +146,16 @@ async def create_completion(
     try:
         from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 
-        decoding_config = await engine.model.get_decoding_config()
-
-        try:
+        if vllm_version >= "0.4.3":
+            decoding_config = await engine.model.get_decoding_config()
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request.guided_decoding_backend or decoding_config.guided_decoding_backend,
                     request,
                     engine.tokenizer,
                 )
             )
-        except TypeError:
+        else:
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request,
@@ -176,7 +177,7 @@ async def create_completion(
     else:
         input_ids = engine.convert_to_inputs(prompt=prompt, max_tokens=request.max_tokens)
 
-    try:
+    if vllm_version >= "0.4.3":
         generator = engine.model.generate(
             {
                 "prompt": prompt,
@@ -186,7 +187,7 @@ async def create_completion(
             request_id,
             lora_request,
         )
-    except TypeError:
+    else:
         generator = engine.model.generate(
             prompt,
             sampling_params,
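
The generate() dispatch is likewise repeated in both routes. A sketch of the
same idea as a shared helper; this is hypothetical, and since the diff elides
the trailing arguments of the pre-0.4.3 call, the old-style ordering below
(prompt, sampling_params, request_id, token ids, lora_request) is an assumption
to verify against the pinned vLLM release:

    # Sketch, not part of the commit; check argument order against your vLLM version.
    def start_generation(engine, prompt, token_ids, sampling_params,
                         request_id, lora_request, vllm_version: str):
        if vllm_version >= "0.4.3":
            # Newer vLLM accepts a single inputs dict.
            return engine.model.generate(
                {"prompt": prompt, "prompt_token_ids": token_ids},  # assumed dict keys
                sampling_params,
                request_id,
                lora_request,
            )
        # Older vLLM takes prompt and token ids positionally (assumed ordering).
        return engine.model.generate(
            prompt,
            sampling_params,
            request_id,
            token_ids,
            lora_request,
        )
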
