-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add guided decoding to TGIS gRPC API #31
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,8 @@ | |
from vllm.logger import init_logger | ||
from vllm.sequence import Logprob | ||
from vllm.tgis_utils import logs | ||
from vllm.tgis_utils.guided_decoding import ( | ||
get_outlines_guided_decoding_logits_processor) | ||
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper, | ||
TypicalLogitsWarperWrapper) | ||
from vllm.tgis_utils.metrics import (FailureReasonLabel, ServiceMetrics, | ||
|
@@ -118,7 +120,8 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace): | |
|
||
async def _post_init(self): | ||
self.config = await self.engine.get_model_config() | ||
self.tokenizer_group = await self.engine.get_tokenizer_group() | ||
# self.tokenizer_group = await self.engine.get_tokenizer_group() | ||
self.tokenizer_group = self.engine.engine.tokenizer | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

I've seen versions of the code where the `get_tokenizer_group` function exists and others where it doesn't. What's happening with this function?

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

@maxdebayser that's from this upstream PR: vllm-project/vllm#3512. It didn't get merged in a timely manner and is now buried in conflicts :( |
||
self.tokenizer = await self.engine.get_tokenizer() | ||
|
||
# Swap in the special TGIS stats logger | ||
|
@@ -389,6 +392,12 @@ async def _validate_and_convert_params( | |
ExpDecayLengthPenaltyWarper(length_penalty=length_penalty_tuple, | ||
eos_token_id=self.tokenizer.eos_token_id)) | ||
|
||
guided_decode_logit_processor = ( | ||
await get_outlines_guided_decoding_logits_processor(decoding, | ||
self.tokenizer)) | ||
if guided_decode_logit_processor is not None: | ||
logits_processors.append(guided_decode_logit_processor) | ||
|
||
time_limit_millis = stopping.time_limit_millis | ||
deadline = time.time( | ||
) + time_limit_millis / 1000.0 if time_limit_millis > 0 else None | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,69 @@ | ||||
import asyncio | ||||
import concurrent.futures | ||||
from copy import copy | ||||
from re import escape as regex_escape | ||||
from typing import Tuple, Union | ||||
|
||||
import vllm.model_executor.guided_decoding.outlines_decoding as outlines_decoding # noqa: E501 | ||||
from vllm.entrypoints.grpc.pb.generation_pb2 import DecodingParameters | ||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( | ||||
GuidedDecodingMode, _get_cached_logits_processor) | ||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( | ||||
JSONLogitsProcessor, RegexLogitsProcessor) | ||||
|
||||
|
||||
async def get_outlines_guided_decoding_logits_processor(
        decoding_params: DecodingParameters,
        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
    """Build the logits processor for the request's guided-decoding params.

    Returns ``None`` when the request carries no guide. Processors are
    cached by (guide, tokenizer); on a cache hit we hand back a shallow
    copy so the underlying FSM is shared while per-request state is fresh.
    """
    guide, mode = _get_guide_and_mode(decoding_params)
    if not guide:
        return None

    # Lazily create the shared executor. Processor construction (FSM
    # compilation) can be slow, so it is dispatched off the asyncio event
    # loop — same approach as the HTTP API's guided-decoding path.
    if outlines_decoding.global_thread_pool is None:
        outlines_decoding.global_thread_pool = (
            concurrent.futures.ThreadPoolExecutor(max_workers=2))

    cached_processor = await asyncio.get_running_loop().run_in_executor(
        outlines_decoding.global_thread_pool,
        _get_cached_logits_processor,
        guide,
        tokenizer,
        mode,
        None,  # guided_whitespace_pattern - TBD
    )

    # Shallow-copy the cached instance so concurrent requests don't share
    # mutable state, then reset its per-request internal state.
    processor = copy(cached_processor)
    processor.init_state()
    return processor
|
||||
|
||||
def _get_guide_and_mode(
    decoding_params: DecodingParameters,
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
    """Extract the (guide string, decoding mode) pair from the request.

    Falls back to the generic JSON grammar when no oneof guide is set but
    the legacy ``format == JSON`` field is; otherwise returns (None, None).
    """
    which = decoding_params.WhichOneof("guided")
    if which == "json_schema":
        return decoding_params.json_schema, GuidedDecodingMode.JSON
    if which == "regex":
        return decoding_params.regex, GuidedDecodingMode.REGEX
    if which == "choice":
        options = decoding_params.choice.choices
        if len(options) < 2:
            raise ValueError("Must provide at least two choices")
        # A choice constraint is implemented as an alternation regex.
        pattern = "(" + "|".join(
            regex_escape(str(option)) for option in options) + ")"
        return pattern, GuidedDecodingMode.CHOICE
    if which == "grammar":
        return decoding_params.grammar, GuidedDecodingMode.GRAMMAR
    if decoding_params.format == DecodingParameters.JSON:
        return outlines_decoding.JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
    return None, None
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.

Unfortunately you cannot have `repeated` fields directly within `oneof`s :( See protocolbuffers/protobuf#2592 (comment).