From e1fa09d8babe8efed4bc823e225724f457538913 Mon Sep 17 00:00:00 2001
From: Noam Gat
Date: Wed, 11 Oct 2023 13:10:29 +0300
Subject: [PATCH] Also exporting build_transformers_prefix_allowed_tokens_fn()
 for transformers integration

---
 lmformatenforcer/__init__.py            |  7 ++++---
 lmformatenforcer/transformerenforcer.py | 25 +++++++++++++++++++------
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/lmformatenforcer/__init__.py b/lmformatenforcer/__init__.py
index 6d5f840..0221440 100644
--- a/lmformatenforcer/__init__.py
+++ b/lmformatenforcer/__init__.py
@@ -3,7 +3,8 @@
            'RegexParser',
            'JsonSchemaParser',
            'TokenEnforcer',
-           'generate_enforced']
+           'generate_enforced',
+           'build_transformers_prefix_allowed_tokens_fn']
 
 from .characterlevelparser import CharacterLevelParser, StringParser
 from .regexparser import RegexParser
@@ -11,7 +12,7 @@
 from .tokenenforcer import TokenEnforcer
 
 try:
-    from .transformerenforcer import generate_enforced
+    from .transformerenforcer import generate_enforced, build_transformers_prefix_allowed_tokens_fn
 except Exception as e:
     import logging
-    logging.warning(f"Could not import generate_enforced(). Transformers-based functionality will not be available. Details: {e}")
+    logging.warning(f"Could not import transformerenforcer. Transformers-based functionality will not be available. Details: {e}")
diff --git a/lmformatenforcer/transformerenforcer.py b/lmformatenforcer/transformerenforcer.py
index d798757..e0ebdfb 100644
--- a/lmformatenforcer/transformerenforcer.py
+++ b/lmformatenforcer/transformerenforcer.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import Callable, List, Tuple, Union
 from transformers import AutoModelForCausalLM
 
 from .characterlevelparser import CharacterLevelParser
@@ -70,11 +70,10 @@ def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple
     return regular_tokens
 
 
-def generate_enforced(model: AutoModelForCausalLM,
-                      tokenizer: PreTrainedTokenizerBase,
-                      character_level_parser: CharacterLevelParser,
-                      **kwargs: dict) -> Union[str, dict]:
-
+def build_transformers_prefix_allowed_tokens_fn(tokenizer: PreTrainedTokenizerBase,
+                                                character_level_parser: CharacterLevelParser) -> Callable[[int, torch.Tensor], List[int]]:
+    """Build the prefix allowed tokens function that transformers will use to filter the tokens generated by the model. The result
+    can be passed to the prefix_allowed_tokens_fn parameter of the generate() method of transformers models or pipeline configurations."""
     regular_tokens = _build_regular_tokens_list(tokenizer)
     token_enforcer = TokenEnforcer(regular_tokens, character_level_parser, tokenizer.decode, tokenizer.eos_token_id)
 
@@ -82,6 +81,20 @@ def transformers_filter_allowed_tokens(batch_id: int, sent: torch.Tensor) -> Lis
         token_sequence = sent.tolist()
         return token_enforcer.get_allowed_tokens(token_sequence)
 
+    return transformers_filter_allowed_tokens
+
+
+def generate_enforced(model: AutoModelForCausalLM,
+                      tokenizer: PreTrainedTokenizerBase,
+                      character_level_parser: CharacterLevelParser,
+                      **kwargs: dict) -> Union[str, dict]:
+    """Generate text from a model while enforcing a given format and generating enforcement diagnostic information.
+    This can be used instead of calling model.generate().
+    If the return_dict_in_generate and output_scores parameters are True, diagnostic information will be returned in the result.
+    If you don't need this, consider using prefix_allowed_tokens_fn + build_transformers_prefix_allowed_tokens_fn() instead."""
+
+    transformers_filter_allowed_tokens = build_transformers_prefix_allowed_tokens_fn(tokenizer, character_level_parser)
+
     is_multi_inputs = kwargs['input_ids'].shape[0] > 1
     is_multi_beams = kwargs.get('num_beams', 1) > 1
     support_diagnostics = not (is_multi_inputs or is_multi_beams)  # TODO: Support diagnostics in these cases as well.
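
For reference, a minimal usage sketch of the newly exported function (not part of the patch). It assumes transformers, pydantic and lm-format-enforcer are installed; the model id and the Answer schema below are illustrative placeholders, not something taken from the library:

# Usage sketch: build a prefix_allowed_tokens_fn callback and pass it to generate().
# The model id "gpt2" and the Answer schema are placeholder assumptions.
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from lmformatenforcer import JsonSchemaParser, build_transformers_prefix_allowed_tokens_fn

class Answer(BaseModel):
    city: str
    population: int

model_id = "gpt2"  # placeholder; any causal LM supported by transformers works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Build the token-filtering callback from the tokenizer and a character-level parser.
parser = JsonSchemaParser(Answer.schema())
prefix_fn = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)

prompt = "Here is information about Paris in JSON format: "
inputs = tokenizer(prompt, return_tensors="pt")
# Pass the callback through the standard prefix_allowed_tokens_fn hook of generate().
output_ids = model.generate(**inputs, max_new_tokens=100, prefix_allowed_tokens_fn=prefix_fn)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))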