From e7c9b2a0b7cbf38ce74d0a1b3a30df7df627cd70 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 20 Mar 2024 15:06:04 -0600 Subject: [PATCH] :bento: lift grpc_server changes Signed-off-by: Joe Runde --- vllm/entrypoints/grpc/grpc_server.py | 123 +++++++++++++++++++-------- 1 file changed, 88 insertions(+), 35 deletions(-) diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py index a0f2b0398..3d41dcfb8 100644 --- a/vllm/entrypoints/grpc/grpc_server.py +++ b/vllm/entrypoints/grpc/grpc_server.py @@ -1,9 +1,9 @@ import argparse import inspect -import logging import time import uuid +import grpc from grpc import aio, StatusCode from typing import Optional, AsyncIterator, Dict, MutableSequence, Any, Union, Tuple, List @@ -21,6 +21,8 @@ from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.sampling_params import LogitsProcessor from vllm.tgis_utils.logits_processors import MinTokensLogitsProcessor, TypicalLogitsWarperWrapper +from vllm.transformers_utils.tokenizer import TokenizerGroup +from vllm.sequence import Logprob from vllm import AsyncLLMEngine, SamplingParams, RequestOutput, CompletionOutput logger = init_logger(__name__) @@ -41,9 +43,9 @@ async def _handle_exception(e: Exception, func, *args, **kwargs): if not isinstance(e, AbortError): if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check context = kwargs.get("context", None) or args[-1] - logging.exception(f"{func.__name__} caused GPU OOM error") + logger.exception(f"{func.__name__} caused GPU OOM error") await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e)) - logging.exception(f"{func.__name__} failed") + logger.exception(f"{func.__name__} failed") raise e @@ -71,9 +73,10 @@ class TextGenerationService(generation_pb2_grpc.GenerationServiceServicer): def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace): self.engine: AsyncLLMEngine = engine + self.tokenizer_group: TokenizerGroup = engine.get_tokenizer_group() self.tokenizer: Union[ PreTrainedTokenizer, - PreTrainedTokenizerFast] = engine.engine.tokenizer.tokenizer + PreTrainedTokenizerFast] = self.tokenizer_group.tokenizer self.config: ModelConfig = None self.max_max_new_tokens = args.max_new_tokens @@ -203,6 +206,7 @@ def _convert_input_details( result.prompt_token_ids, result.prompt_logprobs, resp_options.token_logprobs, + resp_options.token_ranks, resp_options.top_n_tokens, response.input_tokens, ) @@ -237,6 +241,7 @@ def _convert_output(self, output.token_ids, output.logprobs, resp_options.token_logprobs, + resp_options.token_ranks, resp_options.top_n_tokens, response.tokens, token_start_offset, @@ -261,9 +266,6 @@ async def _validate_and_convert_params( if params.decoding.HasField("length_penalty"): raise ValueError( "decoding.length_penalty parameter not yet supported") - if resp_options.token_ranks: - raise ValueError( - "response.token_ranks option not yet supported") # default max may be limited further in later processing max_new_tokens: Optional[int] = None @@ -293,7 +295,7 @@ async def _validate_and_convert_params( # TODO more parameter validation - logprobs = 1 if resp_options.token_logprobs else 0 + logprobs = 1 if resp_options.token_logprobs or resp_options.token_ranks else 0 top_n_tokens = resp_options.top_n_tokens if top_n_tokens: if top_n_tokens > MAX_TOP_N_TOKENS: @@ -397,8 +399,9 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool, def _convert_tokens( self, token_ids: list[int], - logprobs_list: Optional[list[Dict[int, float]]], + logprobs_list: Optional[list[Dict[int, Logprob]]], include_logprobs: bool, + include_ranks: bool, top_n_tokens: int, token_infos: MutableSequence[TokenInfo], # OUT token_start_offset: int = 0, @@ -407,24 +410,34 @@ def _convert_tokens( token_ids = token_ids[token_start_offset:] if logprobs_list is not None: logprobs_list = logprobs_list[token_start_offset:] + #TODO later use get_lora_tokenizer here token_texts = self.tokenizer.convert_ids_to_tokens(token_ids) for i, text in enumerate(token_texts): token_info = TokenInfo(text=text) if logprobs_list is not None: logprobs = logprobs_list[i] - if include_logprobs: - token_info.logprob = logprobs[token_ids[i]] - if top_n_tokens: - items = sorted(logprobs.items(), - key=lambda item: item[1], - reverse=True)[:top_n_tokens] - tt_texts = self.tokenizer.convert_ids_to_tokens( - [tid for tid, _ in items]) - token_info.top_tokens.extend( - TokenInfo.TopToken( - text=tt_text, - logprob=logprob, - ) for tt_text, (_, logprob) in zip(tt_texts, items)) + # Logprobs entry will be None for first prompt token + if logprobs is not None: + if include_logprobs or include_ranks: + logprob = logprobs[token_ids[i]] + if include_logprobs: + token_info.logprob = logprob.logprob + if include_ranks: + token_info.rank = logprob.rank + if top_n_tokens: + items = sorted(logprobs.items(), + key=lambda item: item[1].logprob, + reverse=True)[:top_n_tokens] + #TODO later use get_lora_tokenizer here + tt_texts = self.tokenizer.convert_ids_to_tokens( + [tid for tid, _ in items]) + token_info.top_tokens.extend( + TokenInfo.TopToken( + text=tt_text, + logprob=(logprob.logprob + if include_logprobs else None), + ) + for tt_text, (_, logprob) in zip(tt_texts, items)) token_infos.append(token_info) async def _validate_prompt_and_tokenize( @@ -438,7 +451,8 @@ async def _validate_prompt_and_tokenize( if truncate_input_tokens is not None else {} max_model_len = self.config.max_model_len - input_ids = self.tokenizer(prompt, **tokenize_kwargs).input_ids + input_ids = await self.tokenizer_group.encode_async( + prompt, **tokenize_kwargs) token_num = len(input_ids) if token_num >= max_model_len: @@ -469,17 +483,18 @@ async def _validate_prompt_and_tokenize( @log_rpc_handler_errors async def Tokenize(self, request: BatchedTokenizeRequest, context: ServicerContext) -> BatchedTokenizeResponse: - strings = [req.text for req in request.requests] + responses: List[TokenizeResponse] = [] - #TODO check skip special tokens behaviour (& compare with TGIS) - batch_encoding = self.tokenizer(strings) # TODO + #TODO maybe parallelize, also move convert_ids_to_tokens into the other threads + for req in request.requests: + token_ids = await self.tokenizer_group.encode_async(req.text) + responses.append( + TokenizeResponse( + token_count=len(token_ids), + tokens=None if not request.return_tokens else + self.tokenizer.convert_ids_to_tokens(token_ids))) - return BatchedTokenizeResponse(responses=[ - TokenizeResponse(token_count=len(tokens), - tokens=None if not request.return_tokens else self - .tokenizer.convert_ids_to_tokens(tokens)) - for tokens in batch_encoding.input_ids - ]) + return BatchedTokenizeResponse(responses=responses) @log_rpc_handler_errors async def ModelInfo(self, request: ModelInfoRequest, @@ -494,6 +509,11 @@ async def ModelInfo(self, request: ModelInfoRequest, async def start_grpc_server(engine: AsyncLLMEngine, args: argparse.Namespace) -> aio.Server: + + # Log memory summary after model is loaded + from torch.cuda import memory_summary + logger.info(memory_summary(engine.engine.device_config.device)) + server = aio.server() service = TextGenerationService(engine, args) await service._post_init() @@ -511,10 +531,43 @@ async def start_grpc_server(engine: AsyncLLMEngine, host = "0.0.0.0" if args.host is None else args.host listen_on = f"{host}:{args.grpc_port}" + ssl_keyfile = args.ssl_keyfile + ssl_certfile = args.ssl_certfile + ssl_ca_certs = args.ssl_ca_certs + + if ssl_keyfile and ssl_certfile: + require_client_auth = False + try: + with open(ssl_keyfile, "rb") as f: + ssl_key = f.read() + except Exception as e: + raise ValueError( + f"Error reading `ssl_keyfile` file: {ssl_keyfile}") from e + try: + with open(ssl_certfile, "rb") as f: + ssl_cert = f.read() + except Exception as e: + raise ValueError( + f"Error reading `ssl_certfile` file: {ssl_certfile}") from e + if ssl_ca_certs: + require_client_auth = True + try: + with open(ssl_ca_certs, "rb") as f: + root_certificates = f.read() + except Exception as e: + raise ValueError( + f"Error reading `ssl_ca_certs` file: {ssl_ca_certs}" + ) from e + else: + root_certificates = None + server_credentials = grpc.ssl_server_credentials([(ssl_key, ssl_cert)], + root_certificates, + require_client_auth) + server.add_secure_port(listen_on, server_credentials) + else: + server.add_insecure_port(listen_on) - #TODO add TLS - server.add_insecure_port(listen_on) await server.start() logger.info(f"gRPC Server started at {listen_on}") - return server + return server \ No newline at end of file