diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py
index 3d41dcfb8..3e0b7888f 100644
--- a/vllm/entrypoints/grpc/grpc_server.py
+++ b/vllm/entrypoints/grpc/grpc_server.py
@@ -6,7 +6,8 @@
 import grpc
 from grpc import aio, StatusCode
-from typing import Optional, AsyncIterator, Dict, MutableSequence, Any, Union, Tuple, List
+from typing import (Optional, AsyncIterator, Dict, MutableSequence, Any, Union,
+                    Tuple, List)
 
 from grpc._cython.cygrpc import AbortError
 from grpc.aio import ServicerContext
 
@@ -15,15 +16,17 @@
 from vllm.logger import init_logger
 from vllm.config import ModelConfig
 from vllm.entrypoints.grpc.pb import generation_pb2_grpc
-from vllm.entrypoints.grpc.pb.generation_pb2 import BatchedTokenizeRequest, BatchedGenerationRequest, \
-    SingleGenerationRequest, ModelInfoRequest, BatchedTokenizeResponse, TokenizeResponse, ModelInfoResponse, \
-    GenerationResponse, BatchedGenerationResponse, StopReason, TokenInfo, Parameters, DecodingMethod, ResponseOptions
+from vllm.entrypoints.grpc.pb.generation_pb2 import (
+    BatchedTokenizeRequest, BatchedGenerationRequest, SingleGenerationRequest,
+    ModelInfoRequest, BatchedTokenizeResponse, TokenizeResponse,
+    ModelInfoResponse, GenerationResponse, BatchedGenerationResponse,
+    StopReason, TokenInfo, Parameters, DecodingMethod, ResponseOptions)
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
-from vllm.sampling_params import LogitsProcessor
-from vllm.tgis_utils.logits_processors import MinTokensLogitsProcessor, TypicalLogitsWarperWrapper
+from vllm.tgis_utils.logits_processors import TypicalLogitsWarperWrapper
 from vllm.transformers_utils.tokenizer import TokenizerGroup
 from vllm.sequence import Logprob
-from vllm import AsyncLLMEngine, SamplingParams, RequestOutput, CompletionOutput
+from vllm import (AsyncLLMEngine, SamplingParams, RequestOutput,
+                  CompletionOutput)
 
 logger = init_logger(__name__)
 
@@ -38,8 +41,8 @@ def with_default(value: Any, default: Any) -> Any:
 
 
 async def _handle_exception(e: Exception, func, *args, **kwargs):
-    # We don't log AbortErrors since these correspond to gRPC errors intentionally
-    # raised during handling of requests.
+    # We don't log AbortErrors since these correspond to gRPC errors
+    # intentionally raised during handling of requests.
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
             context = kwargs.get("context", None) or args[-1]
@@ -151,7 +154,7 @@ async def GenerateStream(
         truncate_input_tokens = with_default(
            request.params.truncate_input_tokens, None)
 
-        input_ids, max_is_token_limit = await self._validate_prompt_and_tokenize(
+        input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize(
             sampling_params, truncate_input_tokens, request.request.text,
             context)
 
@@ -186,9 +189,9 @@ async def GenerateStream(
                 time_limit_reached = True
 
             # Convert output text and token_ids to deltas
-            yield self._convert_output(output, resp_options,
-                                       max_is_token_limit, time_limit_reached,
-                                       last_output_length, last_token_count)
+            yield self._convert_output(output, resp_options, max_is_tok_limit,
+                                       time_limit_reached, last_output_length,
+                                       last_token_count)
 
             if time_limit_reached:
                 break
@@ -287,15 +290,19 @@ async def _validate_and_convert_params(
                 raise ValueError(f"min_new_tokens ({min_new_tokens}) "
                                  f"must be <= {self.max_max_new_tokens}")
 
-        if stopping.stop_sequences and len(stopping.stop_sequences) > MAX_STOP_SEQS or \
-                not all(0 < len(ss) <= MAX_STOP_SEQ_LENGTH for ss in stopping.stop_sequences):
+        if stopping.stop_sequences and (
+                len(stopping.stop_sequences) > MAX_STOP_SEQS) or \
+                not all(0 < len(ss) <= MAX_STOP_SEQ_LENGTH
+                        for ss in stopping.stop_sequences):
             raise ValueError(
-                f"can specify at most {MAX_STOP_SEQS} non-empty stop sequences, "
-                f"each not more than {MAX_STOP_SEQ_LENGTH} UTF8 bytes")
+                f"can specify at most {MAX_STOP_SEQS} non-empty stop "
+                f"sequences, each not more than {MAX_STOP_SEQ_LENGTH} "
+                f"UTF8 bytes")
 
         # TODO more parameter validation
 
-        logprobs = 1 if resp_options.token_logprobs or resp_options.token_ranks else 0
+        logprobs = 1 if (resp_options.token_logprobs
+                         or resp_options.token_ranks) else 0
         top_n_tokens = resp_options.top_n_tokens
         if top_n_tokens:
             if top_n_tokens > MAX_TOP_N_TOKENS:
@@ -312,7 +319,6 @@ async def _validate_and_convert_params(
 
         # GAPS:
         # - exp_decay_length_penalty
-        # - return ranks
 
         # NEW FUNCTION TO ADD (later)
         # - presence penalty, freq penalty
@@ -325,20 +331,14 @@ async def _validate_and_convert_params(
         # - skip_special_tokens (per request)
         # - stop_token_ids
 
-        # use logits processors to extend the sampling methods
-        logits_processors: List[LogitsProcessor] = []
-        if min_new_tokens > 0:
-            min_tokens_processor = MinTokensLogitsProcessor(
-                min_tokens=min_new_tokens,
-                # TODO: will eos_tokens_ids need to be adjusted to use the LoRA tokenizer?
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-            logits_processors.append(min_tokens_processor)
-
-        # to match TGIS, only including typical_p processing when using sampling
+        # to match TGIS, only including typical_p processing
+        # when using sampling
         if not greedy and 0.0 < sampling.typical_p < 1.0:
-            logits_processors.append(
-                TypicalLogitsWarperWrapper(mass=sampling.typical_p))
+            logits_processors = [
+                TypicalLogitsWarperWrapper(mass=sampling.typical_p)
+            ]
+        else:
+            logits_processors = None
 
         time_limit_millis = stopping.time_limit_millis
         deadline = time.time(
@@ -346,18 +346,22 @@
 
             sampling_params = SamplingParams(
                 logprobs=logprobs,
-                prompt_logprobs=logprobs if resp_options.input_tokens else None,
+                prompt_logprobs=logprobs
+                if resp_options.input_tokens else None,
                 max_tokens=max_new_tokens,
                 min_tokens=min_new_tokens,
-                temperature=with_default(sampling.temperature, 1.0) if not greedy else 0.0,
+                temperature=with_default(sampling.temperature, 1.0)
+                if not greedy else 0.0,
                 top_k=with_default(sampling.top_k, -1),
                 top_p=with_default(sampling.top_p, 1.0),
                 seed=sampling.seed if sampling.HasField("seed") else None,
-                repetition_penalty=with_default(params.decoding.repetition_penalty, 1.0),
+                repetition_penalty=with_default(
+                    params.decoding.repetition_penalty, 1.0),
                 logits_processors=logits_processors,
                 stop=with_default(stopping.stop_sequences, None),
-                include_stop_str_in_output=stopping.include_stop_sequence \
-                    if stopping.HasField("include_stop_sequence") else self.default_include_stop_seqs,
+                include_stop_str_in_output=stopping.include_stop_sequence
+                if stopping.HasField("include_stop_sequence") else
+                self.default_include_stop_seqs,
                 skip_special_tokens=self.skip_special_tokens,
             )
         except ValueError as e:
@@ -372,9 +376,11 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
         finish_reason = output.finish_reason
         stop_sequence = None
         if finish_reason is None:
-            stop_reason = StopReason.TIME_LIMIT if time_limit_reached else StopReason.NOT_FINISHED
+            stop_reason = StopReason.TIME_LIMIT if (
+                time_limit_reached) else StopReason.NOT_FINISHED
         elif finish_reason == "length":
-            stop_reason = StopReason.TOKEN_LIMIT if max_is_token_limit else StopReason.MAX_TOKENS
+            stop_reason = StopReason.TOKEN_LIMIT if (
+                max_is_token_limit) else StopReason.MAX_TOKENS
         elif finish_reason == "stop":
             stop_reason = StopReason.STOP_SEQUENCE
             # TODO depends on https://github.com/vllm-project/vllm/pull/2976
@@ -447,7 +453,8 @@ async def _validate_prompt_and_tokenize(
         prompt: Optional[str],
         context: ServicerContext,
     ) -> Tuple[List[int], bool]:
-        tokenize_kwargs = {"truncation": True, "max_length": truncate_input_tokens} \
+        tokenize_kwargs = {"truncation": True,
+                           "max_length": truncate_input_tokens} \
             if truncate_input_tokens is not None else {}
 
         max_model_len = self.config.max_model_len
@@ -463,8 +470,8 @@ async def _validate_prompt_and_tokenize(
         if token_num + min_new_tokens > max_model_len:
             await context.abort(
                 StatusCode.INVALID_ARGUMENT,
-                f"input tokens ({token_num}) plus min_new_tokens ({min_new_tokens}) must be <= {max_model_len}"
-            )
+                f"input tokens ({token_num}) plus min_new_tokens "
+                f"({min_new_tokens}) must be <= {max_model_len}")
 
         max_new_tokens: Optional[int] = sampling_params.max_tokens
         max_is_token_limit = False
@@ -485,7 +492,8 @@ async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:
         responses: List[TokenizeResponse] = []
 
-        #TODO maybe parallelize, also move convert_ids_to_tokens into the other threads
+        #TODO maybe parallelize, also move convert_ids_to_tokens
+        # into the other threads
         for req in request.requests:
             token_ids = await self.tokenizer_group.encode_async(req.text)
             responses.append(
@@ -524,7 +532,8 @@ async def start_grpc_server(engine: AsyncLLMEngine,
 
     #TODO add reflection
     # SERVICE_NAMES = (
-    #     generation_pb2.DESCRIPTOR.services_by_name["GenerationService"].full_name,
+    #     generation_pb2.DESCRIPTOR.services_by_name["GenerationService"]
+    #     .full_name,
     #     reflection.SERVICE_NAME,
     # )
     # reflection.enable_server_reflection(SERVICE_NAMES, server)
@@ -570,4 +579,4 @@ async def start_grpc_server(engine: AsyncLLMEngine,
     await server.start()
     logger.info(f"gRPC Server started at {listen_on}")
 
-    return server
\ No newline at end of file
+    return server
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index f584389ec..5c36ee631 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1,3 +1,4 @@
+import argparse
 import asyncio
 from contextlib import asynccontextmanager
 import os
@@ -67,7 +68,6 @@ def parse_args():
 
     return parsed_args
 
 
-
 # Add prometheus asgi middleware to route /metrics requests
 metrics_app = make_asgi_app()
 app.mount("/metrics", metrics_app)
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 7311c54da..c3a5df67e 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -12,6 +12,7 @@
 from vllm.entrypoints.openai.serving_engine import LoRA
 from vllm.tgis_utils.args import EnvVarArgumentParser
 
+
 class LoRAParserAction(argparse.Action):
 
     def __call__(self, parser, namespace, values, option_string=None):
diff --git a/vllm/tgis_utils/args.py b/vllm/tgis_utils/args.py
index a7ad4d643..5fc924d90 100644
--- a/vllm/tgis_utils/args.py
+++ b/vllm/tgis_utils/args.py
@@ -82,7 +82,9 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
         args.model = args.model_name
     if args.max_sequence_length is not None:
         if args.max_model_len not in (None, args.max_sequence_length):
-            raise ValueError("Inconsistent max_model_len and max_sequence_length arg values")
+            raise ValueError(
+                "Inconsistent max_model_len and max_sequence_length arg values"
+            )
         args.max_model_len = args.max_sequence_length
     if args.dtype_str is not None:
         if args.dtype not in (None, 'auto', args.dtype_str):
@@ -96,11 +98,12 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
     if args.num_gpus is not None and args.num_shard is not None \
             and args.num_gpus != args.num_shard:
         raise ValueError("Inconsistent num_gpus and num_shard arg values")
-    num_gpus = args.num_gpus if args.num_gpus is not None else args.num_shard
+    num_gpus = args.num_gpus if (args.num_gpus
+                                 is not None) else args.num_shard
     if args.tensor_parallel_size not in [None, 1, num_gpus]:
         raise ValueError(
-            "Inconsistent tensor_parallel_size and num_gpus/num_shard arg values"
-        )
+            "Inconsistent tensor_parallel_size and num_gpus/num_shard "
+            "arg values")
     args.tensor_parallel_size = num_gpus
 
     return args
diff --git a/vllm/tgis_utils/logits_processors.py b/vllm/tgis_utils/logits_processors.py
index eb55e27ce..51b78de1b 100644
--- a/vllm/tgis_utils/logits_processors.py
+++ b/vllm/tgis_utils/logits_processors.py
@@ -1,23 +1,9 @@
-from typing import List, Union
+from typing import List
 
 import torch
 from transformers.generation.logits_process import TypicalLogitsWarper
 
 
-class MinTokensLogitsProcessor:
-
-    def __init__(self, min_tokens: int, eos_token_id: Union[int, List[int]]):
-        self.min_tokens = min_tokens
-        self.eos_token_ids = torch.tensor(eos_token_id)
-
-    def __call__(self, token_ids: List[int],
-                 logits: torch.tensor) -> torch.tensor:
-        # token_ids is only output tokens
-        if len(token_ids) < self.min_tokens:
-            logits[self.eos_token_ids] = -float("inf")
-        return logits
-
-
 class TypicalLogitsWarperWrapper:
 
     def __init__(self, mass: float):
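Note on the `MinTokensLogitsProcessor` removal: minimum-length enforcement now comes from the `min_tokens` field of `SamplingParams` (the gRPC server already passes `min_tokens=min_new_tokens`), so the custom EOS-masking processor is dropped and `logits_processors` is only populated for typical_p sampling. Below is a minimal sketch of the resulting usage; the numeric values are illustrative only and not taken from the patch.

```python
from vllm import SamplingParams
from vllm.tgis_utils.logits_processors import TypicalLogitsWarperWrapper

# Sketch: the minimum generation length is enforced by the engine via
# min_tokens, so no custom logits processor is needed for it anymore.
sampling_params = SamplingParams(
    max_tokens=64,  # illustrative cap on generated tokens
    min_tokens=8,   # EOS / stop token ids suppressed until 8 tokens are emitted
    # typical_p decoding is the only case that still requires a processor
    logits_processors=[TypicalLogitsWarperWrapper(mass=0.8)],
)
```

This mirrors the new `_validate_and_convert_params` logic above, where `logits_processors` stays `None` unless sampling with `0.0 < typical_p < 1.0` is requested.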