Tokenization-related updates to grpc_server layer
njhill committed Mar 22, 2024
1 parent da1d3ad commit 88d3697
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -23,7 +23,7 @@
     StopReason, TokenInfo, Parameters, DecodingMethod, ResponseOptions)
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.tgis_utils.logits_processors import TypicalLogitsWarperWrapper
-from vllm.transformers_utils.tokenizer import TokenizerGroup
+from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.sequence import Logprob
 from vllm import (AsyncLLMEngine, SamplingParams, RequestOutput,
                   CompletionOutput)
@@ -76,10 +76,11 @@ class TextGenerationService(generation_pb2_grpc.GenerationServiceServicer):
 
     def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
         self.engine: AsyncLLMEngine = engine
-        self.tokenizer_group: TokenizerGroup = engine.get_tokenizer_group()
-        self.tokenizer: Union[
-            PreTrainedTokenizer,
-            PreTrainedTokenizerFast] = self.tokenizer_group.tokenizer
+
+        # These set in _post_init()
+        self.tokenizer_group: BaseTokenizerGroup = None
+        self.tokenizer: Union[PreTrainedTokenizer,
+                              PreTrainedTokenizerFast] = None
         self.config: ModelConfig = None
 
         self.max_max_new_tokens = args.max_new_tokens
@@ -88,6 +89,8 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
 
     async def _post_init(self):
         self.config = await self.engine.get_model_config()
+        self.tokenizer_group = await self.engine.get_tokenizer_group()
+        self.tokenizer = await self.engine.get_tokenizer()
 
     @log_rpc_handler_errors
     async def Generate(self, request: BatchedGenerationRequest,
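
For context, a minimal, self-contained sketch of the deferred-initialization pattern this diff adopts: the constructor is synchronous, so attributes that depend on the engine's async accessors start as None and are filled in by an async _post_init(). The DummyEngine class and its return values below are hypothetical stand-ins for illustration; only the accessor names (get_model_config, get_tokenizer_group, get_tokenizer) come from the diff above.

# Minimal sketch (not the actual vLLM code): a synchronous constructor cannot
# await the engine's async accessors, so the attributes start as None and are
# populated later by an async _post_init() call.
import asyncio
from typing import Optional


class DummyEngine:
    # Hypothetical stand-in for AsyncLLMEngine; only the method names mirror
    # the accessors used in the diff above.
    async def get_model_config(self) -> dict:
        return {"model": "demo-model"}

    async def get_tokenizer_group(self) -> str:
        return "dummy-tokenizer-group"

    async def get_tokenizer(self) -> str:
        return "dummy-tokenizer"


class Service:
    def __init__(self, engine: DummyEngine):
        self.engine = engine
        # These are set in _post_init(), mirroring the diff.
        self.config: Optional[dict] = None
        self.tokenizer_group: Optional[str] = None
        self.tokenizer: Optional[str] = None

    async def _post_init(self) -> None:
        self.config = await self.engine.get_model_config()
        self.tokenizer_group = await self.engine.get_tokenizer_group()
        self.tokenizer = await self.engine.get_tokenizer()


async def main() -> None:
    service = Service(DummyEngine())
    await service._post_init()  # must complete before the service handles RPCs
    print(service.config, service.tokenizer_group, service.tokenizer)


asyncio.run(main())

Fetching the tokenizer group through the engine rather than constructing it eagerly also appears consistent with the import switching from the concrete TokenizerGroup to the BaseTokenizerGroup base type.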
