🍱 lift grpc_server changes
Signed-off-by: Joe Runde <[email protected]>
joerunde committed Mar 20, 2024
1 parent 1479404 commit e7c9b2a
Showing 1 changed file with 88 additions and 35 deletions.
vllm/entrypoints/grpc/grpc_server.py (123 changes: 88 additions & 35 deletions)
@@ -1,9 +1,9 @@
 import argparse
 import inspect
-import logging
 import time
 import uuid
 
+import grpc
 from grpc import aio, StatusCode
 
 from typing import Optional, AsyncIterator, Dict, MutableSequence, Any, Union, Tuple, List
@@ -21,6 +21,8 @@
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.sampling_params import LogitsProcessor
 from vllm.tgis_utils.logits_processors import MinTokensLogitsProcessor, TypicalLogitsWarperWrapper
+from vllm.transformers_utils.tokenizer import TokenizerGroup
+from vllm.sequence import Logprob
 from vllm import AsyncLLMEngine, SamplingParams, RequestOutput, CompletionOutput
 
 logger = init_logger(__name__)
@@ -41,9 +43,9 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError":  #TODO check
             context = kwargs.get("context", None) or args[-1]
-            logging.exception(f"{func.__name__} caused GPU OOM error")
+            logger.exception(f"{func.__name__} caused GPU OOM error")
             await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
-        logging.exception(f"{func.__name__} failed")
+        logger.exception(f"{func.__name__} failed")
     raise e
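
Note: this fix routes both records through the module-level `logger` built with vLLM's `init_logger` (see the imports above) instead of the root `logging` module, so they pick up vLLM's handler and formatting. A minimal sketch of the pattern, assuming only that `init_logger` behaves like a preconfigured `logging.getLogger`:

    from vllm.logger import init_logger

    logger = init_logger(__name__)  # vLLM-configured, module-scoped logger

    async def handler(request):
        try:
            ...
        except Exception:
            # .exception() logs at ERROR level and appends the traceback
            logger.exception("handler failed")
            raise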


@@ -71,9 +73,9 @@ class TextGenerationService(generation_pb2_grpc.GenerationServiceServicer):
 
     def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
         self.engine: AsyncLLMEngine = engine
+        self.tokenizer_group: TokenizerGroup = engine.get_tokenizer_group()
         self.tokenizer: Union[
             PreTrainedTokenizer,
-            PreTrainedTokenizerFast] = engine.engine.tokenizer.tokenizer
+            PreTrainedTokenizerFast] = self.tokenizer_group.tokenizer
         self.config: ModelConfig = None
 
         self.max_max_new_tokens = args.max_new_tokens
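
Note: the constructor now goes through the engine's TokenizerGroup rather than reaching into the private `engine.engine.tokenizer.tokenizer` chain; the group also exposes the async encode path used later in this diff. A hedged sketch of the wiring, using only names that appear in this commit:

    # inside an async context; engine is the AsyncLLMEngine
    tokenizer_group = engine.get_tokenizer_group()
    tokenizer = tokenizer_group.tokenizer            # underlying HF tokenizer
    token_ids = await tokenizer_group.encode_async("some prompt text")
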
@@ -203,6 +206,7 @@ def _convert_input_details(
             result.prompt_token_ids,
             result.prompt_logprobs,
             resp_options.token_logprobs,
+            resp_options.token_ranks,
             resp_options.top_n_tokens,
             response.input_tokens,
         )
@@ -237,6 +241,7 @@ def _convert_output(self,
             output.token_ids,
             output.logprobs,
             resp_options.token_logprobs,
+            resp_options.token_ranks,
             resp_options.top_n_tokens,
             response.tokens,
             token_start_offset,
@@ -261,9 +266,6 @@ async def _validate_and_convert_params(
         if params.decoding.HasField("length_penalty"):
             raise ValueError(
                 "decoding.length_penalty parameter not yet supported")
-        if resp_options.token_ranks:
-            raise ValueError(
-                "response.token_ranks option not yet supported")
 
         # default max may be limited further in later processing
         max_new_tokens: Optional[int] = None
@@ -293,7 +295,7 @@ async def _validate_and_convert_params(
 
         # TODO more parameter validation
 
-        logprobs = 1 if resp_options.token_logprobs else 0
+        logprobs = 1 if resp_options.token_logprobs or resp_options.token_ranks else 0
         top_n_tokens = resp_options.top_n_tokens
         if top_n_tokens:
             if top_n_tokens > MAX_TOP_N_TOKENS:
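
Note: requesting `token_ranks` now forces `logprobs = 1`, so the engine attaches a `Logprob` entry (which carries the rank consumed by `_convert_tokens` below) to every sampled token even when the client never asked for the logprob values themselves. An illustrative sketch of the resulting sampling parameters; only the `logprobs` field is taken from this diff:

    from vllm import SamplingParams

    # logprobs=1 -> one Logprob entry per sampled token, rank included
    params = SamplingParams(max_tokens=16, logprobs=1)
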
@@ -397,8 +399,9 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
     def _convert_tokens(
             self,
             token_ids: list[int],
-            logprobs_list: Optional[list[Dict[int, float]]],
+            logprobs_list: Optional[list[Dict[int, Logprob]]],
             include_logprobs: bool,
+            include_ranks: bool,
             top_n_tokens: int,
             token_infos: MutableSequence[TokenInfo],  # OUT
             token_start_offset: int = 0,
@@ -407,24 +410,34 @@
         token_ids = token_ids[token_start_offset:]
         if logprobs_list is not None:
             logprobs_list = logprobs_list[token_start_offset:]
-        #TODO later use get_lora_tokenizer here
         token_texts = self.tokenizer.convert_ids_to_tokens(token_ids)
         for i, text in enumerate(token_texts):
             token_info = TokenInfo(text=text)
             if logprobs_list is not None:
                 logprobs = logprobs_list[i]
-                if include_logprobs:
-                    token_info.logprob = logprobs[token_ids[i]]
-                if top_n_tokens:
-                    items = sorted(logprobs.items(),
-                                   key=lambda item: item[1],
-                                   reverse=True)[:top_n_tokens]
-                    tt_texts = self.tokenizer.convert_ids_to_tokens(
-                        [tid for tid, _ in items])
-                    token_info.top_tokens.extend(
-                        TokenInfo.TopToken(
-                            text=tt_text,
-                            logprob=logprob,
-                        ) for tt_text, (_, logprob) in zip(tt_texts, items))
+                # Logprobs entry will be None for first prompt token
+                if logprobs is not None:
+                    if include_logprobs or include_ranks:
+                        logprob = logprobs[token_ids[i]]
+                        if include_logprobs:
+                            token_info.logprob = logprob.logprob
+                        if include_ranks:
+                            token_info.rank = logprob.rank
+                    if top_n_tokens:
+                        items = sorted(logprobs.items(),
+                                       key=lambda item: item[1].logprob,
+                                       reverse=True)[:top_n_tokens]
+                        #TODO later use get_lora_tokenizer here
+                        tt_texts = self.tokenizer.convert_ids_to_tokens(
+                            [tid for tid, _ in items])
+                        token_info.top_tokens.extend(
+                            TokenInfo.TopToken(
+                                text=tt_text,
+                                logprob=(logprob.logprob
+                                         if include_logprobs else None),
+                            )
+                            for tt_text, (_, logprob) in zip(tt_texts, items))
             token_infos.append(token_info)
 
     async def _validate_prompt_and_tokenize(
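
Note: `logprobs_list` entries are now `Dict[int, Logprob]` rather than `Dict[int, float]`, so the top-n sort key becomes `item[1].logprob` and the rank rides along. An illustrative example of the new shape, assuming `Logprob` is constructible from the `logprob` and `rank` fields read above (token ids and values are made up):

    from vllm.sequence import Logprob

    # one dict per generated position: token_id -> Logprob
    position = {1542: Logprob(logprob=-0.05, rank=1),
                9876: Logprob(logprob=-3.20, rank=2)}
    best = sorted(position.items(),
                  key=lambda item: item[1].logprob,
                  reverse=True)[:1]  # mirrors the top_n_tokens slice
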
@@ -438,7 +451,8 @@ async def _validate_prompt_and_tokenize(
             if truncate_input_tokens is not None else {}
 
         max_model_len = self.config.max_model_len
-        input_ids = self.tokenizer(prompt, **tokenize_kwargs).input_ids
+        input_ids = await self.tokenizer_group.encode_async(
+            prompt, **tokenize_kwargs)
         token_num = len(input_ids)
 
         if token_num >= max_model_len:
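
Note: prompt tokenization now awaits `TokenizerGroup.encode_async` instead of calling the HF tokenizer synchronously, so a long prompt no longer blocks the event loop. A hedged sketch of the call; the kwarg names are hypothetical, since the hunk only shows that `tokenize_kwargs` is empty unless `truncate_input_tokens` is set:

    # inside a coroutine; the truncation kwargs shown are assumptions
    tokenize_kwargs = {"truncation": True, "max_length": 1024}
    input_ids = await tokenizer_group.encode_async(prompt, **tokenize_kwargs)
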
@@ -469,17 +483,18 @@ async def _validate_prompt_and_tokenize(
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:
-        strings = [req.text for req in request.requests]
+        responses: List[TokenizeResponse] = []
 
         #TODO check skip special tokens behaviour (& compare with TGIS)
-        batch_encoding = self.tokenizer(strings)  # TODO
+        #TODO maybe parallelize, also move convert_ids_to_tokens into the other threads
+        for req in request.requests:
+            token_ids = await self.tokenizer_group.encode_async(req.text)
+            responses.append(
+                TokenizeResponse(
+                    token_count=len(token_ids),
+                    tokens=None if not request.return_tokens else
+                    self.tokenizer.convert_ids_to_tokens(token_ids)))
 
-        return BatchedTokenizeResponse(responses=[
-            TokenizeResponse(token_count=len(tokens),
-                             tokens=None if not request.return_tokens else self
-                             .tokenizer.convert_ids_to_tokens(tokens))
-            for tokens in batch_encoding.input_ids
-        ])
+        return BatchedTokenizeResponse(responses=responses)
 
     @log_rpc_handler_errors
     async def ModelInfo(self, request: ModelInfoRequest,
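
Note: requests are still encoded sequentially, one await at a time; the new "#TODO maybe parallelize" comment points at fanning the encodes out. One hypothetical shape that could take (not part of this commit), using asyncio.gather:

    import asyncio

    async def encode_all(tokenizer_group, requests):
        # fan out one encode_async per request and await them together
        return await asyncio.gather(
            *(tokenizer_group.encode_async(req.text) for req in requests))
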
@@ -494,6 +509,11 @@ async def ModelInfo(self, request: ModelInfoRequest,
 
 async def start_grpc_server(engine: AsyncLLMEngine,
                             args: argparse.Namespace) -> aio.Server:
+
+    # Log memory summary after model is loaded
+    from torch.cuda import memory_summary
+    logger.info(memory_summary(engine.engine.device_config.device))
+
     server = aio.server()
     service = TextGenerationService(engine, args)
     await service._post_init()
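
Note: `torch.cuda.memory_summary` returns a human-readable table of CUDA allocator statistics for a single device, so this logs a memory snapshot right after the model finishes loading. Standalone usage looks like:

    import torch

    # prints allocated/reserved/active memory stats for CUDA device 0
    print(torch.cuda.memory_summary(device=0))
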
@@ -511,10 +531,43 @@ async def start_grpc_server(
 
     host = "0.0.0.0" if args.host is None else args.host
     listen_on = f"{host}:{args.grpc_port}"
+    ssl_keyfile = args.ssl_keyfile
+    ssl_certfile = args.ssl_certfile
+    ssl_ca_certs = args.ssl_ca_certs
+
+    if ssl_keyfile and ssl_certfile:
+        require_client_auth = False
+        try:
+            with open(ssl_keyfile, "rb") as f:
+                ssl_key = f.read()
+        except Exception as e:
+            raise ValueError(
+                f"Error reading `ssl_keyfile` file: {ssl_keyfile}") from e
+        try:
+            with open(ssl_certfile, "rb") as f:
+                ssl_cert = f.read()
+        except Exception as e:
+            raise ValueError(
+                f"Error reading `ssl_certfile` file: {ssl_certfile}") from e
+        if ssl_ca_certs:
+            require_client_auth = True
+            try:
+                with open(ssl_ca_certs, "rb") as f:
+                    root_certificates = f.read()
+            except Exception as e:
+                raise ValueError(
+                    f"Error reading `ssl_ca_certs` file: {ssl_ca_certs}"
+                ) from e
+        else:
+            root_certificates = None
+        server_credentials = grpc.ssl_server_credentials([(ssl_key, ssl_cert)],
+                                                         root_certificates,
+                                                         require_client_auth)
+        server.add_secure_port(listen_on, server_credentials)
+    else:
+        server.add_insecure_port(listen_on)
 
-    #TODO add TLS
-    server.add_insecure_port(listen_on)
     await server.start()
     logger.info(f"gRPC Server started at {listen_on}")
 
     return server
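
Note: supplying `ssl_ca_certs` flips `require_client_auth` to True, so the secure port enforces mutual TLS; with only a key and certificate the server does plain server-side TLS. A hypothetical client for the mutual-TLS case (file paths and target are placeholders):

    import grpc

    with open("ca.crt", "rb") as f:
        ca = f.read()
    with open("client.key", "rb") as f:
        key = f.read()
    with open("client.crt", "rb") as f:
        crt = f.read()

    # the client presents its own certificate because the server was
    # started with require_client_auth=True
    creds = grpc.ssl_channel_credentials(root_certificates=ca,
                                         private_key=key,
                                         certificate_chain=crt)
    channel = grpc.secure_channel("localhost:8033", creds)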
