Linting, adjustments related to min_tokens updates
njhill committed Mar 20, 2024
1 parent e7c9b2a commit f9fad31
Showing 5 changed files with 64 additions and 65 deletions.
99 changes: 54 additions & 45 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -6,7 +6,8 @@
import grpc
from grpc import aio, StatusCode

from typing import Optional, AsyncIterator, Dict, MutableSequence, Any, Union, Tuple, List
from typing import (Optional, AsyncIterator, Dict, MutableSequence, Any, Union,
Tuple, List)

from grpc._cython.cygrpc import AbortError
from grpc.aio import ServicerContext
@@ -15,15 +16,17 @@
from vllm.logger import init_logger
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.pb import generation_pb2_grpc
from vllm.entrypoints.grpc.pb.generation_pb2 import BatchedTokenizeRequest, BatchedGenerationRequest, \
SingleGenerationRequest, ModelInfoRequest, BatchedTokenizeResponse, TokenizeResponse, ModelInfoResponse, \
GenerationResponse, BatchedGenerationResponse, StopReason, TokenInfo, Parameters, DecodingMethod, ResponseOptions
from vllm.entrypoints.grpc.pb.generation_pb2 import (
BatchedTokenizeRequest, BatchedGenerationRequest, SingleGenerationRequest,
ModelInfoRequest, BatchedTokenizeResponse, TokenizeResponse,
ModelInfoResponse, GenerationResponse, BatchedGenerationResponse,
StopReason, TokenInfo, Parameters, DecodingMethod, ResponseOptions)
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.sampling_params import LogitsProcessor
from vllm.tgis_utils.logits_processors import MinTokensLogitsProcessor, TypicalLogitsWarperWrapper
from vllm.tgis_utils.logits_processors import TypicalLogitsWarperWrapper
from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Logprob
from vllm import AsyncLLMEngine, SamplingParams, RequestOutput, CompletionOutput
from vllm import (AsyncLLMEngine, SamplingParams, RequestOutput,
CompletionOutput)

logger = init_logger(__name__)

@@ -38,8 +41,8 @@ def with_default(value: Any, default: Any) -> Any:


async def _handle_exception(e: Exception, func, *args, **kwargs):
# We don't log AbortErrors since these correspond to gRPC errors intentionally
# raised during handling of requests.
# We don't log AbortErrors since these correspond to gRPC errors
# intentionally raised during handling of requests.
if not isinstance(e, AbortError):
if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
context = kwargs.get("context", None) or args[-1]
@@ -151,7 +154,7 @@ async def GenerateStream(
truncate_input_tokens = with_default(
request.params.truncate_input_tokens, None)

input_ids, max_is_token_limit = await self._validate_prompt_and_tokenize(
input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize(
sampling_params, truncate_input_tokens, request.request.text,
context)

@@ -186,9 +189,9 @@ async def GenerateStream(
time_limit_reached = True

# Convert output text and token_ids to deltas
yield self._convert_output(output, resp_options,
max_is_token_limit, time_limit_reached,
last_output_length, last_token_count)
yield self._convert_output(output, resp_options, max_is_tok_limit,
time_limit_reached, last_output_length,
last_token_count)
if time_limit_reached:
break

@@ -287,15 +290,19 @@ async def _validate_and_convert_params(
raise ValueError(f"min_new_tokens ({min_new_tokens}) "
f"must be <= {self.max_max_new_tokens}")

if stopping.stop_sequences and len(stopping.stop_sequences) > MAX_STOP_SEQS or \
not all(0 < len(ss) <= MAX_STOP_SEQ_LENGTH for ss in stopping.stop_sequences):
if stopping.stop_sequences and (
len(stopping.stop_sequences) > MAX_STOP_SEQS) or \
not all(0 < len(ss) <= MAX_STOP_SEQ_LENGTH
for ss in stopping.stop_sequences):
raise ValueError(
f"can specify at most {MAX_STOP_SEQS} non-empty stop sequences, "
f"each not more than {MAX_STOP_SEQ_LENGTH} UTF8 bytes")
f"can specify at most {MAX_STOP_SEQS} non-empty stop "
f"sequences, each not more than {MAX_STOP_SEQ_LENGTH} "
f"UTF8 bytes")

# TODO more parameter validation

logprobs = 1 if resp_options.token_logprobs or resp_options.token_ranks else 0
logprobs = 1 if (resp_options.token_logprobs
or resp_options.token_ranks) else 0
top_n_tokens = resp_options.top_n_tokens
if top_n_tokens:
if top_n_tokens > MAX_TOP_N_TOKENS:
@@ -312,7 +319,6 @@

# GAPS:
# - exp_decay_length_penalty
# - return ranks

# NEW FUNCTION TO ADD (later)
# - presence penalty, freq penalty
@@ -325,39 +331,37 @@
# - skip_special_tokens (per request)
# - stop_token_ids

# use logits processors to extend the sampling methods
logits_processors: List[LogitsProcessor] = []
if min_new_tokens > 0:
min_tokens_processor = MinTokensLogitsProcessor(
min_tokens=min_new_tokens,
# TODO: will eos_tokens_ids need to be adjusted to use the LoRA tokenizer?
eos_token_id=self.tokenizer.eos_token_id,
)
logits_processors.append(min_tokens_processor)

# to match TGIS, only including typical_p processing when using sampling
# to match TGIS, only including typical_p processing
# when using sampling
if not greedy and 0.0 < sampling.typical_p < 1.0:
logits_processors.append(
TypicalLogitsWarperWrapper(mass=sampling.typical_p))
logits_processors = [
TypicalLogitsWarperWrapper(mass=sampling.typical_p)
]
else:
logits_processors = None

time_limit_millis = stopping.time_limit_millis
deadline = time.time(
) + time_limit_millis / 1000.0 if time_limit_millis > 0 else None

sampling_params = SamplingParams(
logprobs=logprobs,
prompt_logprobs=logprobs if resp_options.input_tokens else None,
prompt_logprobs=logprobs
if resp_options.input_tokens else None,
max_tokens=max_new_tokens,
min_tokens=min_new_tokens,
temperature=with_default(sampling.temperature, 1.0) if not greedy else 0.0,
temperature=with_default(sampling.temperature, 1.0)
if not greedy else 0.0,
top_k=with_default(sampling.top_k, -1),
top_p=with_default(sampling.top_p, 1.0),
seed=sampling.seed if sampling.HasField("seed") else None,
repetition_penalty=with_default(params.decoding.repetition_penalty, 1.0),
repetition_penalty=with_default(
params.decoding.repetition_penalty, 1.0),
logits_processors=logits_processors,
stop=with_default(stopping.stop_sequences, None),
include_stop_str_in_output=stopping.include_stop_sequence \
if stopping.HasField("include_stop_sequence") else self.default_include_stop_seqs,
include_stop_str_in_output=stopping.include_stop_sequence
if stopping.HasField("include_stop_sequence") else
self.default_include_stop_seqs,
skip_special_tokens=self.skip_special_tokens,
)
except ValueError as e:
@@ -372,9 +376,11 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
finish_reason = output.finish_reason
stop_sequence = None
if finish_reason is None:
stop_reason = StopReason.TIME_LIMIT if time_limit_reached else StopReason.NOT_FINISHED
stop_reason = StopReason.TIME_LIMIT if (
time_limit_reached) else StopReason.NOT_FINISHED
elif finish_reason == "length":
stop_reason = StopReason.TOKEN_LIMIT if max_is_token_limit else StopReason.MAX_TOKENS
stop_reason = StopReason.TOKEN_LIMIT if (
max_is_token_limit) else StopReason.MAX_TOKENS
elif finish_reason == "stop":
stop_reason = StopReason.STOP_SEQUENCE
# TODO depends on https://github.com/vllm-project/vllm/pull/2976
@@ -447,7 +453,8 @@ async def _validate_prompt_and_tokenize(
prompt: Optional[str],
context: ServicerContext,
) -> Tuple[List[int], bool]:
tokenize_kwargs = {"truncation": True, "max_length": truncate_input_tokens} \
tokenize_kwargs = {"truncation": True,
"max_length": truncate_input_tokens} \
if truncate_input_tokens is not None else {}

max_model_len = self.config.max_model_len
@@ -463,8 +470,8 @@ async def _validate_prompt_and_tokenize(
if token_num + min_new_tokens > max_model_len:
await context.abort(
StatusCode.INVALID_ARGUMENT,
f"input tokens ({token_num}) plus min_new_tokens ({min_new_tokens}) must be <= {max_model_len}"
)
f"input tokens ({token_num}) plus min_new_tokens "
f"({min_new_tokens}) must be <= {max_model_len}")

max_new_tokens: Optional[int] = sampling_params.max_tokens
max_is_token_limit = False
@@ -485,7 +492,8 @@ async def Tokenize(self, request: BatchedTokenizeRequest,
context: ServicerContext) -> BatchedTokenizeResponse:
responses: List[TokenizeResponse] = []

#TODO maybe parallelize, also move convert_ids_to_tokens into the other threads
#TODO maybe parallelize, also move convert_ids_to_tokens
# into the other threads
for req in request.requests:
token_ids = await self.tokenizer_group.encode_async(req.text)
responses.append(
@@ -524,7 +532,8 @@ async def start_grpc_server(engine: AsyncLLMEngine,
#TODO add reflection

# SERVICE_NAMES = (
# generation_pb2.DESCRIPTOR.services_by_name["GenerationService"].full_name,
# generation_pb2.DESCRIPTOR.services_by_name["GenerationService"]
# .full_name,
# reflection.SERVICE_NAME,
# )
# reflection.enable_server_reflection(SERVICE_NAMES, server)
@@ -570,4 +579,4 @@ async def start_grpc_server(engine: AsyncLLMEngine,
await server.start()
logger.info(f"gRPC Server started at {listen_on}")

return server
return server
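Beyond the line-length linting, the substantive change in this file is that the minimum-new-tokens constraint is now passed straight to the engine via SamplingParams(min_tokens=...) instead of being enforced by the custom MinTokensLogitsProcessor removed further down. A minimal sketch of the new call, with hypothetical example values (the real handler derives these from the gRPC request):

from vllm import SamplingParams

min_new_tokens, max_new_tokens = 16, 128  # hypothetical request values

# Previously an explicit MinTokensLogitsProcessor masked out the EOS token
# until min_new_tokens had been generated; now the engine enforces it.
sampling_params = SamplingParams(
    max_tokens=max_new_tokens,
    min_tokens=min_new_tokens,
    # logits_processors now carries only the typical_p warper when sampling,
    # or None otherwise.
    logits_processors=None,
)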
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -1,3 +1,4 @@
import argparse
import asyncio
from contextlib import asynccontextmanager
import os
@@ -67,7 +68,6 @@ def parse_args():
return parsed_args



# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -12,6 +12,7 @@
from vllm.entrypoints.openai.serving_engine import LoRA
from vllm.tgis_utils.args import EnvVarArgumentParser


class LoRAParserAction(argparse.Action):

def __call__(self, parser, namespace, values, option_string=None):
11 changes: 7 additions & 4 deletions vllm/tgis_utils/args.py
@@ -82,7 +82,9 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
args.model = args.model_name
if args.max_sequence_length is not None:
if args.max_model_len not in (None, args.max_sequence_length):
raise ValueError("Inconsistent max_model_len and max_sequence_length arg values")
raise ValueError(
"Inconsistent max_model_len and max_sequence_length arg values"
)
args.max_model_len = args.max_sequence_length
if args.dtype_str is not None:
if args.dtype not in (None, 'auto', args.dtype_str):
@@ -96,11 +98,12 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
if args.num_gpus is not None and args.num_shard is not None \
and args.num_gpus != args.num_shard:
raise ValueError("Inconsistent num_gpus and num_shard arg values")
num_gpus = args.num_gpus if args.num_gpus is not None else args.num_shard
num_gpus = args.num_gpus if (args.num_gpus
is not None) else args.num_shard
if args.tensor_parallel_size not in [None, 1, num_gpus]:
raise ValueError(
"Inconsistent tensor_parallel_size and num_gpus/num_shard arg values"
)
"Inconsistent tensor_parallel_size and num_gpus/num_shard "
"arg values")
args.tensor_parallel_size = num_gpus

return args
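For illustration, a standalone sketch of the num_gpus/num_shard reconciliation shown in this hunk, using a hypothetical helper name and example namespace (the real function also normalizes model_name, max_sequence_length and dtype_str, and may guard these checks differently):

import argparse


def reconcile_gpu_args(args: argparse.Namespace) -> argparse.Namespace:
    # num_gpus and num_shard are treated as aliases; if both are given,
    # they must agree.
    if (args.num_gpus is not None and args.num_shard is not None
            and args.num_gpus != args.num_shard):
        raise ValueError("Inconsistent num_gpus and num_shard arg values")
    num_gpus = args.num_gpus if args.num_gpus is not None else args.num_shard
    # An explicitly set tensor_parallel_size must match the resolved count.
    if args.tensor_parallel_size not in [None, 1, num_gpus]:
        raise ValueError("Inconsistent tensor_parallel_size and "
                         "num_gpus/num_shard arg values")
    args.tensor_parallel_size = num_gpus
    return args


example = argparse.Namespace(num_gpus=2, num_shard=2, tensor_parallel_size=None)
print(reconcile_gpu_args(example).tensor_parallel_size)  # -> 2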
16 changes: 1 addition & 15 deletions vllm/tgis_utils/logits_processors.py
@@ -1,23 +1,9 @@
from typing import List, Union
from typing import List

import torch
from transformers.generation.logits_process import TypicalLogitsWarper


class MinTokensLogitsProcessor:

def __init__(self, min_tokens: int, eos_token_id: Union[int, List[int]]):
self.min_tokens = min_tokens
self.eos_token_ids = torch.tensor(eos_token_id)

def __call__(self, token_ids: List[int],
logits: torch.tensor) -> torch.tensor:
# token_ids is only output tokens
if len(token_ids) < self.min_tokens:
logits[self.eos_token_ids] = -float("inf")
return logits


class TypicalLogitsWarperWrapper:

def __init__(self, mass: float):
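The wrapper's __call__ falls outside the lines shown in this hunk, so the following is only a plausible sketch of how TypicalLogitsWarperWrapper might adapt vLLM's (token_ids, logits) logits-processor convention to the transformers warper, not the committed code:

from typing import List

import torch
from transformers.generation.logits_process import TypicalLogitsWarper


class TypicalLogitsWarperWrapper:
    # Sketch only; the actual implementation in this file may differ.

    def __init__(self, mass: float):
        # TypicalLogitsWarper keeps only the "locally typical" probability mass.
        self.warper = TypicalLogitsWarper(mass=mass)

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        # transformers warpers expect a (batch, vocab) scores tensor, while
        # vLLM passes a 1-D logits vector for a single sequence.
        return self.warper(input_ids=None,
                           scores=logits.reshape(1, -1)).flatten()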
