service.py
from __future__ import annotations

import logging, typing, uuid

import bentoml, fastapi, typing_extensions, annotated_types

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

ENGINE_CONFIG = {
    'model': 'Qwen/Qwen2.5-72B-Instruct',
    'max_model_len': 2048,
    'tensor_parallel_size': 2,
    'enable_prefix_caching': True,
}
MAX_TOKENS = 1024

# FastAPI app that exposes the OpenAI-compatible routes, mounted under /v1 below.
openai_api_app = fastapi.FastAPI()


@bentoml.asgi_app(openai_api_app, path='/v1')
@bentoml.service(
    name='bentovllm-qwen2.5-72b-instruct-service',
    traffic={'timeout': 300},
    resources={'gpu': 2, 'gpu_type': 'nvidia-a100-80gb'},
    labels={'owner': 'bentoml-team', 'type': 'prebuilt'},
    image=bentoml.images.PythonImage(python_version='3.11', lock_python_packages=False)
    .requirements_file('requirements.txt')
    .run('uv pip install flashinfer-python --find-links https://flashinfer.ai/whl/cu124/torch2.5'),
)
class VLLM:
    model_id = ENGINE_CONFIG['model']
    # Pull the model from Hugging Face, skipping PyTorch checkpoints in favor of safetensors.
    model = bentoml.models.HuggingFaceModel(model_id, exclude=['*.pth', '*.pt'])

    def __init__(self):
        from openai import AsyncOpenAI

        # Client pointing back at the OpenAI-compatible routes this service mounts at /v1.
        self.openai = AsyncOpenAI(base_url='http://127.0.0.1:3000/v1', api_key='dummy')

    @bentoml.on_startup
    async def init_engine(self) -> None:
        import vllm.entrypoints.openai.api_server as vllm_api_server
        from vllm.utils import FlexibleArgumentParser
        from vllm.entrypoints.openai.cli_args import make_arg_parser

        # Build vLLM's CLI-style arguments programmatically and apply ENGINE_CONFIG.
        args = make_arg_parser(FlexibleArgumentParser()).parse_args([])
        args.model = self.model
        args.disable_log_requests = True
        args.max_log_len = 1000
        args.served_model_name = [self.model_id]
        args.request_logger = None
        args.disable_log_stats = True
        for key, value in ENGINE_CONFIG.items():
            setattr(args, key, value)

        # Mount a subset of vLLM's OpenAI-compatible endpoints onto the FastAPI app.
        router = fastapi.APIRouter(lifespan=vllm_api_server.lifespan)
        OPENAI_ENDPOINTS = [
            ['/chat/completions', vllm_api_server.create_chat_completion, ['POST']],
            ['/models', vllm_api_server.show_available_models, ['GET']],
        ]
        for route, endpoint, methods in OPENAI_ENDPOINTS:
            router.add_api_route(path=route, endpoint=endpoint, methods=methods, include_in_schema=True)
        openai_api_app.include_router(router)

        # Start the async engine and keep its context manager around for shutdown.
        self.engine_context = vllm_api_server.build_async_engine_client(args)
        self.engine = await self.engine_context.__aenter__()
        self.model_config = await self.engine.get_model_config()
        self.tokenizer = await self.engine.get_tokenizer()

        # Enable tool calling for the OpenAI-compatible chat endpoint.
        args.enable_auto_tool_choice = True
        args.tool_call_parser = 'llama3_json'
        await vllm_api_server.init_app_state(self.engine, self.model_config, openai_api_app.state, args)

    @bentoml.on_shutdown
    async def teardown_engine(self):
        # Exit the engine's async context manager when the service shuts down.
        await self.engine_context.__aexit__(GeneratorExit, None, None)

    @bentoml.api
    async def generate(
        self,
        prompt: str = 'Who are you? Please respond in pirate speak!',
        max_tokens: typing_extensions.Annotated[
            int, annotated_types.Ge(128), annotated_types.Le(MAX_TOKENS)
        ] = MAX_TOKENS,
    ) -> typing.AsyncGenerator[str, None]:
        from vllm import SamplingParams, TokensPrompt
        from vllm.entrypoints.chat_utils import parse_chat_messages, apply_hf_chat_template

        params = SamplingParams(max_tokens=max_tokens)

        # Wrap the raw prompt as a chat message and render it with the model's chat template.
        messages = [dict(role='user', content=[dict(type='text', text=prompt)])]
        conversation, _ = parse_chat_messages(messages, self.model_config, self.tokenizer, content_format='string')
        prompt = TokensPrompt(
            prompt_token_ids=apply_hf_chat_template(
                self.tokenizer,
                conversation=conversation,
                add_generation_prompt=True,
                continue_final_message=False,
                chat_template=None,
                tokenize=True,
            )
        )

        # Stream only the newly generated text on each iteration.
        stream = self.engine.generate(request_id=uuid.uuid4().hex, prompt=prompt, sampling_params=params)
        cursor = 0
        async for request_output in stream:
            text = request_output.outputs[0].text
            yield text[cursor:]
            cursor = len(text)
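
# Usage sketch (an assumption for illustration, not part of the original file): once the
# service is running via `bentoml serve`, the /v1 mount exposes OpenAI-compatible routes,
# so a standard OpenAI client pointed at http://localhost:3000/v1 should work:
#
#   from openai import OpenAI
#   client = OpenAI(base_url='http://localhost:3000/v1', api_key='dummy')
#   resp = client.chat.completions.create(
#       model='Qwen/Qwen2.5-72B-Instruct',
#       messages=[{'role': 'user', 'content': 'Who are you?'}],
#   )
#   print(resp.choices[0].message.content)
#
# The native `generate` endpoint can likewise be called with bentoml.SyncHTTPClient,
# which iterates over the streamed text chunks as they are produced:
#
#   import bentoml
#   with bentoml.SyncHTTPClient('http://localhost:3000') as http_client:
#       for chunk in http_client.generate(prompt='Hello!', max_tokens=256):
#           print(chunk, end='', flush=True)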