Add guided decoding to TGIS gRPC API #31

Merged · 1 commit · May 30, 2024
25 changes: 25 additions & 0 deletions proto/generation.proto
@@ -99,6 +99,31 @@ message DecodingParameters {
  // Exponentially increases the score of the EOS token
  // once start_index tokens have been generated
  optional LengthPenalty length_penalty = 2;

  enum ResponseFormat {
    // Plain text, no constraints
    TEXT = 0;
    // Valid json
    JSON = 1;
  }

  message StringChoices {
    repeated string choices = 1;
  }

  // Mutually-exclusive guided decoding options
  oneof guided {
    // Output will be in the specified format
    ResponseFormat format = 3;
    // Output will follow the provided JSON schema
    string json_schema = 4;
    // Output will follow the provided regex pattern
    string regex = 5;
    // Output will be exactly one of the specified choices
    StringChoices choice = 6;
Comment (Contributor Author):
Unfortunately you cannot have repeated fields directly within oneofs :(

protocolbuffers/protobuf#2592 (comment)

    // Output will follow the provided context free grammar
    string grammar = 7;
  }
}
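For illustration, a minimal sketch of how a client might populate these new fields from Python via the generated `generation_pb2` module (the import path is the one used in `grpc_server.py` below; everything else is standard protobuf API):

```python
# A sketch of populating the new guided-decoding fields from Python,
# using the generated generation_pb2 module.
from vllm.entrypoints.grpc.pb.generation_pb2 import DecodingParameters

# Constrain output to valid JSON via the ResponseFormat enum.
json_params = DecodingParameters(format=DecodingParameters.JSON)

# Constrain output to exactly one of a fixed set of strings. The
# StringChoices wrapper exists because repeated fields cannot sit
# directly inside a oneof.
choice_params = DecodingParameters(
    choice=DecodingParameters.StringChoices(choices=["yes", "no"]))

# Members of the oneof are mutually exclusive; setting one clears
# the others.
regex_params = DecodingParameters(regex=r"\d{3}-\d{4}")
assert regex_params.WhichOneof("guided") == "regex"
```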


11 changes: 10 additions & 1 deletion vllm/entrypoints/grpc/grpc_server.py
@@ -34,6 +34,8 @@
from vllm.logger import init_logger
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.guided_decoding import (
    get_outlines_guided_decoding_logits_processor)
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
                                               TypicalLogitsWarperWrapper)
from vllm.tgis_utils.metrics import (FailureReasonLabel, ServiceMetrics,
@@ -118,7 +120,8 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):

    async def _post_init(self):
        self.config = await self.engine.get_model_config()
-       self.tokenizer_group = await self.engine.get_tokenizer_group()
+       # self.tokenizer_group = await self.engine.get_tokenizer_group()
+       self.tokenizer_group = self.engine.engine.tokenizer
Comment (Contributor):
I've seen versions of the code where the get_tokenizer_group function exists and others where it doesn't. What's happening with this function?

Comment (Collaborator):
@maxdebayser that's from this upstream PR vllm-project/vllm#3512

It didn't get merged in a timely manner and is now buried in conflicts :(

        self.tokenizer = await self.engine.get_tokenizer()

        # Swap in the special TGIS stats logger
@@ -389,6 +392,12 @@ async def _validate_and_convert_params(
                ExpDecayLengthPenaltyWarper(length_penalty=length_penalty_tuple,
                                            eos_token_id=self.tokenizer.eos_token_id))

        guided_decode_logit_processor = (
            await get_outlines_guided_decoding_logits_processor(decoding,
                                                                self.tokenizer))
        if guided_decode_logit_processor is not None:
            logits_processors.append(guided_decode_logit_processor)

        time_limit_millis = stopping.time_limit_millis
        deadline = time.time(
        ) + time_limit_millis / 1000.0 if time_limit_millis > 0 else None
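For context on what gets appended here: vLLM logits processors are callables invoked at each decode step with the token ids generated so far and the raw logits, and they return adjusted logits. A toy sketch of that contract (illustrative only, not the outlines implementation):

```python
from typing import List

import torch


def toy_guided_processor(input_ids: List[int],
                         scores: torch.Tensor) -> torch.Tensor:
    """Toy logits processor: only token ids 0 and 1 stay sampleable.

    Real guided processors (e.g. RegexLogitsProcessor) instead consult
    an FSM state derived from input_ids to compute the allowed set.
    """
    mask = torch.full_like(scores, float("-inf"))
    mask[[0, 1]] = 0.0  # allowed token ids (toy example)
    return scores + mask
```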
69 changes: 69 additions & 0 deletions vllm/tgis_utils/guided_decoding.py
@@ -0,0 +1,69 @@
import asyncio
import concurrent.futures
from copy import copy
from re import escape as regex_escape
from typing import Tuple, Union

import vllm.model_executor.guided_decoding.outlines_decoding as outlines_decoding  # noqa: E501
from vllm.entrypoints.grpc.pb.generation_pb2 import DecodingParameters
from vllm.model_executor.guided_decoding.outlines_decoding import (
    GuidedDecodingMode, _get_cached_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    JSONLogitsProcessor, RegexLogitsProcessor)


async def get_outlines_guided_decoding_logits_processor(
        decoding_params: DecodingParameters,
        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
    """
    Check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    guide, mode = _get_guide_and_mode(decoding_params)
    if not guide:
        return None

    if outlines_decoding.global_thread_pool is None:
        outlines_decoding.global_thread_pool = (
            concurrent.futures.ThreadPoolExecutor(max_workers=2))
Comment (Collaborator):
I haven't looked much at logits processors, why does this require its own thread pool?

Comment (Contributor):
It's the same code as here: `global_thread_pool = concurrent.futures.ThreadPoolExecutor(...)` in `outlines_decoding`. If I'm not mistaken, only the construction of the logits processor happens in another thread. But if the logits processor is cached, I'm not sure what the benefit is of having another thread build the object.

Comment (Contributor Author):
Yes, that's right. The code is the same as in the HTTP API. It's dispatched to a thread pool to avoid blocking the asyncio event loop, but I think it could be made more efficient, since we only care about this when the LP is not already cached. In any case, we can fix that as a follow-on, since we need to fix the related concurrency bug anyhow.
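A sketch of the follow-on described above: check the cache synchronously and only pay the executor dispatch on a miss. The `_CACHE` dict and `_build_processor` helper are hypothetical stand-ins, not code from this PR:

```python
import asyncio
import concurrent.futures
from copy import copy

_CACHE = {}  # hypothetical (guide, mode) -> logits-processor cache


def _build_processor(guide, tokenizer, mode):
    """Stand-in for the expensive FSM construction done by outlines."""
    raise NotImplementedError


async def get_processor_cache_first(guide, tokenizer, mode,
                                    executor: concurrent.futures.Executor):
    key = (guide, mode)
    processor = _CACHE.get(key)
    if processor is None:
        # Only hop to a worker thread when we actually have to build
        # the FSM; compilation can take seconds for large schemas.
        loop = asyncio.get_running_loop()
        processor = await loop.run_in_executor(
            executor, _build_processor, guide, tokenizer, mode)
        _CACHE[key] = processor
    # Shallow-copy so concurrent requests don't share mutable FSM
    # state, mirroring the copy()/init_state() pattern below.
    processor = copy(processor)
    processor.init_state()
    return processor
```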

    loop = asyncio.get_running_loop()

    result = await loop.run_in_executor(
        outlines_decoding.global_thread_pool,
        _get_cached_logits_processor,
        guide,
        tokenizer,
        mode,
        None,  # guided_whitespace_pattern - TBD
    )

    logits_processor = copy(result)
    # reset logits processor's internal state
    logits_processor.init_state()
    return logits_processor


def _get_guide_and_mode(
    decoding_params: DecodingParameters,
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
    guided = decoding_params.WhichOneof("guided")
    if guided is not None:
        if guided == "json_schema":
            return decoding_params.json_schema, GuidedDecodingMode.JSON
        if guided == "regex":
            return decoding_params.regex, GuidedDecodingMode.REGEX
        if guided == "choice":
            choice_list = decoding_params.choice.choices
            if len(choice_list) < 2:
                raise ValueError("Must provide at least two choices")
            # choice just uses regex
            choices = [regex_escape(str(choice)) for choice in choice_list]
            choices_regex = "(" + "|".join(choices) + ")"
            return choices_regex, GuidedDecodingMode.CHOICE
        if guided == "grammar":
            return decoding_params.grammar, GuidedDecodingMode.GRAMMAR
        if decoding_params.format == DecodingParameters.JSON:
            return outlines_decoding.JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
    return None, None
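Putting the pieces together, a rough usage sketch of the server-side flow this PR adds (mirroring the `_validate_and_convert_params` change above; note that a two-item choice list compiles down to the regex `(yes|no)`):

```python
from vllm.entrypoints.grpc.pb.generation_pb2 import DecodingParameters
from vllm.tgis_utils.guided_decoding import (
    get_outlines_guided_decoding_logits_processor)


async def build_logits_processors(tokenizer):
    # Guide the model to answer with exactly "yes" or "no".
    decoding = DecodingParameters(
        choice=DecodingParameters.StringChoices(choices=["yes", "no"]))

    logits_processors = []
    processor = await get_outlines_guided_decoding_logits_processor(
        decoding, tokenizer)
    if processor is not None:  # None when no guided option was set
        logits_processors.append(processor)
    return logits_processors
```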