
Add guided decoding to TGIS gRPC API #31

Merged
merged 1 commit into main on May 30, 2024
Conversation

njhill (Contributor) commented on May 22, 2024

Within the existing decoding request parameter section:

enum ResponseFormat {
  // Plain text, no constraints
  TEXT = 0;
  // Valid json
  JSON = 1;
}

message StringChoices {
  repeated string choices = 1;
}

// Mutually-exclusive guided decoding options
oneof guided {
  // Output will be in the specified format
  ResponseFormat format = 3;
  // Output will follow the provided JSON schema
  string json_schema = 4;
  // Output will follow the provided regex pattern
  string regex = 5;
  // Output will be exactly one of the specified choices
  StringChoices choice = 6;
  // Output will follow the provided context free grammar
  string grammar = 7;
}
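
To illustrate how these options are meant to be used, here is a rough client-side sketch. The module name (generation_pb2) and the enclosing message name (DecodingParameters) are assumptions for the example, not part of this PR; only the guided-decoding oneof fields and the StringChoices wrapper come from the snippet above.

# Illustrative sketch only: generation_pb2 and DecodingParameters are assumed
# names for the generated stubs; the oneof fields themselves are from this PR.
from generation_pb2 import DecodingParameters

# Ask for output that is valid JSON via the response-format option.
json_params = DecodingParameters(format=DecodingParameters.ResponseFormat.JSON)

# Ask for output that is exactly one of a fixed set of strings. The repeated
# field lives in the StringChoices wrapper because repeated fields cannot
# appear directly inside a oneof (see the comment below).
choice_params = DecodingParameters(
    choice=DecodingParameters.StringChoices(choices=["yes", "no", "maybe"]))

# The options are mutually exclusive: assigning another member of the oneof
# clears whichever member was previously set.
choice_params.regex = r"\d{4}-\d{2}-\d{2}"
assert choice_params.WhichOneof("guided") == "regex"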

// Output will follow the provided regex pattern
string regex = 5;
// Output will be exactly one of the specified choices
StringChoices choice = 6;
njhill (Contributor Author) commented:

Unfortunately you cannot have repeated fields directly within oneofs :(

protocolbuffers/protobuf#2592 (comment)

njhill force-pushed the guided branch 3 times, most recently from 384d566 to f9ee133 on May 22, 2024 16:21

if outlines_decoding.global_thread_pool is None:
    outlines_decoding.global_thread_pool = (
        concurrent.futures.ThreadPoolExecutor(max_workers=2))
Collaborator commented:
I haven't looked much at logits processors, why does this require its own thread pool?

Contributor replied:
It's the same code as here: global_thread_pool = concurrent.futures.ThreadPoolExecutor(...). If I'm not mistaken, only the construction of the logits processor happens in another thread. But if the logits processor is cached, I'm not sure what the benefit is of having another thread build the object.

njhill (Contributor Author) replied:
Yes, that's right. The code is the same as that in the HTTP API. It's dispatched to a thread pool to avoid blocking the asyncio event loop, but I think it could be made more efficient since we only care about this in the case that the LP is not already cached. In any case we can fix that as a follow-on, since we need to fix that related concurrency bug anyhow.
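
For reference, a rough sketch of the pattern being discussed (with hypothetical names, not the actual vLLM/TGIS code): the blocking construction of the logits processor is offloaded to a small thread pool via run_in_executor, and the executor hop is skipped when the processor is already cached.

import asyncio
import concurrent.futures

# Sketch only: build_logits_processor and _LP_CACHE are hypothetical stand-ins.
_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=2)
_LP_CACHE = {}


def build_logits_processor(guide_spec: str):
    # Stand-in for the expensive, CPU-bound step (e.g. compiling an FSM from a
    # regex or JSON schema); returns some callable logits processor.
    return object()


async def get_logits_processor(guide_spec: str):
    # Fast path: if the processor is already cached there is no blocking work
    # to offload, so the thread-pool hop can be skipped entirely.
    if guide_spec in _LP_CACHE:
        return _LP_CACHE[guide_spec]

    # Slow path: build the processor in a worker thread so the asyncio event
    # loop serving gRPC requests is not blocked while it is constructed.
    loop = asyncio.get_running_loop()
    lp = await loop.run_in_executor(_EXECUTOR, build_logits_processor, guide_spec)
    _LP_CACHE[guide_spec] = lp
    return lp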

@@ -118,7 +120,8 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
 
     async def _post_init(self):
         self.config = await self.engine.get_model_config()
-        self.tokenizer_group = await self.engine.get_tokenizer_group()
+        # self.tokenizer_group = await self.engine.get_tokenizer_group()
+        self.tokenizer_group = self.engine.engine.tokenizer
Contributor commented:
I've seen versions of the code where the get_tokenizer_group function exists and others where it doesn't. What's happening with this function?

Collaborator replied:
@maxdebayser that's from this upstream PR vllm-project/vllm#3512

It didn't get merged in a timely manner and is now buried in conflicts :(

maxdebayser (Contributor) left a review comment:

Since the bug reported in issue https://github.ibm.com/ai-foundation/fmaas-inference-server/issues/718 is not caused by the code in this PR, I think we can merge it and fix the problem in another PR.

njhill merged commit 3dc2819 into main on May 30, 2024
14 checks passed
njhill deleted the guided branch on May 30, 2024 at 00:28