From 88d3697a269e2e2b45a61bdb668367442b7336f1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 22 Mar 2024 15:03:59 -0700 Subject: [PATCH] Tokenization-related updates to grpc_server layer --- vllm/entrypoints/grpc/grpc_server.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py index 3e0b7888f..88fb6578b 100644 --- a/vllm/entrypoints/grpc/grpc_server.py +++ b/vllm/entrypoints/grpc/grpc_server.py @@ -23,7 +23,7 @@ StopReason, TokenInfo, Parameters, DecodingMethod, ResponseOptions) from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.tgis_utils.logits_processors import TypicalLogitsWarperWrapper -from vllm.transformers_utils.tokenizer import TokenizerGroup +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.sequence import Logprob from vllm import (AsyncLLMEngine, SamplingParams, RequestOutput, CompletionOutput) @@ -76,10 +76,11 @@ class TextGenerationService(generation_pb2_grpc.GenerationServiceServicer): def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace): self.engine: AsyncLLMEngine = engine - self.tokenizer_group: TokenizerGroup = engine.get_tokenizer_group() - self.tokenizer: Union[ - PreTrainedTokenizer, - PreTrainedTokenizerFast] = self.tokenizer_group.tokenizer + + # These set in _post_init() + self.tokenizer_group: BaseTokenizerGroup = None + self.tokenizer: Union[PreTrainedTokenizer, + PreTrainedTokenizerFast] = None self.config: ModelConfig = None self.max_max_new_tokens = args.max_new_tokens @@ -88,6 +89,8 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace): async def _post_init(self): self.config = await self.engine.get_model_config() + self.tokenizer_group = await self.engine.get_tokenizer_group() + self.tokenizer = await self.engine.get_tokenizer() @log_rpc_handler_errors async def Generate(self, request: BatchedGenerationRequest,