Skip to content

Commit

Permalink
Temporary fix for input truncation
Browse files: browse the repository at this point in the history
  • Loading branch information
njhill committed Mar 28, 2024
1 parent 029c685 commit cde0474
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
21 changes: 17 additions & 4 deletions vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,26 @@ async def _validate_prompt_and_tokenize(
prompt: Optional[str],
context: ServicerContext,
) -> Tuple[List[int], bool]:
tokenize_kwargs = {"truncation": True,
"max_length": truncate_input_tokens} \
if truncate_input_tokens is not None else {}

max_model_len = self.config.max_model_len
# tokenize_kwargs = {"truncation": True,
# "max_length": truncate_input_tokens} \
# if truncate_input_tokens is not None else {
# "truncation": True, "max_length": max_model_len + 1}
tokenize_kwargs = {}

input_ids = await self.tokenizer_group.encode_async(
prompt, **tokenize_kwargs)

# TODO: this is temporary until a truncation option is added
# to the TokenizerGroup encode methods.
if truncate_input_tokens and truncate_input_tokens < len(input_ids):
input_ids = input_ids[-truncate_input_tokens:]
if not sampling_params.skip_special_tokens:
add_bos_token = getattr(self.tokenizer, "add_bos_token", False)
if add_bos_token:
input_ids[0] = self.tokenizer.bos_token_id
# -----------------------------------------------

token_num = len(input_ids)

if token_num >= max_model_len:
Expand Down
3 changes: 2 additions & 1 deletion vllm/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
args.dtype = args.dtype_str
if args.quantize:
if args.quantization and args.quantization != args.quantize:
raise ValueError("Inconsistent quantize and quantization arg values")
raise ValueError(
"Inconsistent quantize and quantization arg values")
args.quantization = args.quantize
if args.num_gpus is not None or args.num_shard is not None:
if args.num_gpus is not None and args.num_shard is not None \
Expand Down

0 comments on commit cde0474

Please sign in to comment.