Skip to content

Commit

Permalink
Add guided decoding to TGIS gRPC API
Browse files Browse the repository at this point in the history
  enum ResponseFormat {
    // Plain text, no constraints
    TEXT = 0;
    // Valid json
    JSON = 1;
  }

  message StringChoices {
    repeated string choices = 1;
  }

  // Mutually-exclusive guided decoding options
  oneof guided {
    // Output will be in the specified format
    ResponseFormat format = 3;
    // Output will follow the provided JSON schema
    string json_schema = 4;
    // Output will follow the provided regex pattern
    string regex = 5;
    // Output will be exactly one of the specified choices
    StringChoices choice = 6;
    // Output will follow the provided context free grammar
    string grammar = 7;
  }

Signed-off-by: Nick Hill <[email protected]>
  • Loading branch information
njhill committed May 22, 2024
1 parent 066041a commit 2e49d28
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 1 deletion.
25 changes: 25 additions & 0 deletions proto/generation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,31 @@ message DecodingParameters {
// Exponentially increases the score of the EOS token
// once start_index tokens have been generated
optional LengthPenalty length_penalty = 2;

enum ResponseFormat {
// Plain text, no constraints
TEXT = 0;
// Valid json
JSON = 1;
}

message StringChoices {
repeated string choices = 1;
}

// Mutually-exclusive guided decoding options
oneof guided {
// Output will be in the specified format
ResponseFormat format = 3;
// Output will follow the provided JSON schema
string json_schema = 4;
// Output will follow the provided regex pattern
string regex = 5;
// Output will be exactly one of the specified choices
StringChoices choice = 6;
// Output will follow the provided context free grammar
string grammar = 7;
}
}


Expand Down
11 changes: 10 additions & 1 deletion vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from vllm.logger import init_logger
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.guided_decoding import (
get_outlines_guided_decoding_logits_processor)
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
TypicalLogitsWarperWrapper)
from vllm.tgis_utils.metrics import (FailureReasonLabel, ServiceMetrics,
Expand Down Expand Up @@ -118,7 +120,8 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):

async def _post_init(self):
self.config = await self.engine.get_model_config()
self.tokenizer_group = await self.engine.get_tokenizer_group()
# self.tokenizer_group = await self.engine.get_tokenizer_group()
self.tokenizer_group = self.engine.engine.tokenizer
self.tokenizer = await self.engine.get_tokenizer()

# Swap in the special TGIS stats logger
Expand Down Expand Up @@ -389,6 +392,12 @@ async def _validate_and_convert_params(
ExpDecayLengthPenaltyWarper(length_penalty=length_penalty_tuple,
eos_token_id=self.tokenizer.eos_token_id))

guided_decode_logit_processor = (
await get_outlines_guided_decoding_logits_processor(decoding,
self.tokenizer))
if guided_decode_logit_processor is not None:
logits_processors.append(guided_decode_logit_processor)

time_limit_millis = stopping.time_limit_millis
deadline = time.time(
) + time_limit_millis / 1000.0 if time_limit_millis > 0 else None
Expand Down
69 changes: 69 additions & 0 deletions vllm/tgis_utils/guided_decoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import asyncio
import concurrent.futures
from copy import copy
from re import escape as regex_escape
from typing import Tuple, Union

import vllm.model_executor.guided_decoding.outlines_decoding as outlines_decoding # noqa: E501
from vllm.entrypoints.grpc.pb.generation_pb2 import DecodingParameters
from vllm.model_executor.guided_decoding.outlines_decoding import (
GuidedDecodingMode, _get_cached_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)


async def get_outlines_guided_decoding_logits_processor(
decoding_params: DecodingParameters,
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(decoding_params)
if not guide:
return None

if outlines_decoding.global_thread_pool is None:
outlines_decoding.global_thread_pool = (
concurrent.futures.ThreadPoolExecutor(max_workers=2))
loop = asyncio.get_running_loop()

result = await loop.run_in_executor(
outlines_decoding.global_thread_pool,
_get_cached_logits_processor,
guide,
tokenizer,
mode,
None, # guided_whitespace_pattern - TBD
)

logits_processor = copy(result)
# reset logits processor's internal state
logits_processor.init_state()
return logits_processor


def _get_guide_and_mode(
decoding_params: DecodingParameters,
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
guided = decoding_params.WhichOneof("guided")
if guided is not None:
if guided == "json_schema":
return decoding_params.json_schema, GuidedDecodingMode.JSON
if guided == "regex":
return decoding_params.regex, GuidedDecodingMode.REGEX
if guided == "choice":
choice_list = decoding_params.choice.choices
if len(choice_list) < 2:
raise ValueError("Must provide at least two choices")
# choice just uses regex
choices = [regex_escape(str(choice)) for choice in choice_list]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
if guided == "grammar":
return decoding_params.grammar, GuidedDecodingMode.GRAMMAR
if decoding_params.format == DecodingParameters.JSON:
return outlines_decoding.JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
return None, None

0 comments on commit 2e49d28

Please sign in to comment.