Skip to content

Commit

Permalink
Also exporting build_transformers_prefix_allowed_tokens_fn() for transformers integration
Browse files Browse the repository at this point in the history
  • Loading branch information
noamgat committed Oct 11, 2023
1 parent be4b2cc commit e1fa09d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
7 changes: 4 additions & 3 deletions lmformatenforcer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
'RegexParser',
'JsonSchemaParser',
'TokenEnforcer',
'generate_enforced']
'generate_enforced',
'build_transformers_prefix_allowed_tokens_fn']

from .characterlevelparser import CharacterLevelParser, StringParser
from .regexparser import RegexParser
from .jsonschemaparser import JsonSchemaParser
from .tokenenforcer import TokenEnforcer

try:
from .transformerenforcer import generate_enforced
from .transformerenforcer import generate_enforced, build_transformers_prefix_allowed_tokens_fn
except Exception as e:
import logging
logging.warning(f"Could not import generate_enforced(). Transformers-based functionality will not be available. Details: {e}")
logging.warning(f"Could not import transformerenforcer. Transformers-based functionality will not be available. Details: {e}")
25 changes: 19 additions & 6 deletions lmformatenforcer/transformerenforcer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Union
from typing import Callable, List, Tuple, Union
from transformers import AutoModelForCausalLM

from .characterlevelparser import CharacterLevelParser
Expand Down Expand Up @@ -70,18 +70,31 @@ def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple
return regular_tokens


def generate_enforced(model: AutoModelForCausalLM,
tokenizer: PreTrainedTokenizerBase,
character_level_parser: CharacterLevelParser,
**kwargs: dict) -> Union[str, dict]:

def build_transformers_prefix_allowed_tokens_fn(tokenizer: PreTrainedTokenizerBase,
                                                character_level_parser: CharacterLevelParser) -> Callable[[int, torch.Tensor], List[int]]:
    """Create a ``prefix_allowed_tokens_fn`` callback for the transformers library.

    The returned callable filters the tokens the model may emit so that the
    generated text conforms to ``character_level_parser``. Pass it as the
    ``prefix_allowed_tokens_fn`` argument of ``model.generate()`` or of a
    transformers pipeline configuration.
    """
    # Build the enforcer once up front; the closure below reuses it on every step.
    enforcer = TokenEnforcer(_build_regular_tokens_list(tokenizer),
                             character_level_parser,
                             tokenizer.decode,
                             tokenizer.eos_token_id)

    def _prefix_allowed_tokens(batch_id: int, input_ids: torch.Tensor) -> List[int]:
        # transformers calls this per generation step with the tokens so far.
        return enforcer.get_allowed_tokens(input_ids.tolist())

    return _prefix_allowed_tokens


def generate_enforced(model: AutoModelForCausalLM,
tokenizer: PreTrainedTokenizerBase,
character_level_parser: CharacterLevelParser,
**kwargs: dict) -> Union[str, dict]:
"""Generate text from a model while enforcing a given format, generating enforcing diagnostic information.
This can be used instead of calling model.generate().
If return_dict_in_generate and output_scores parameters are True, diagnostic information will be returned in the result.
If you don't need this, consider using prefix_allowed_tokens_fn + build_transformers_prefix_allowed_tokens_fn() instead"""

transformers_filter_allowed_tokens = build_transformers_prefix_allowed_tokens_fn(tokenizer, character_level_parser)

is_multi_inputs = kwargs['input_ids'].shape[0] > 1
is_multi_beams = kwargs.get('num_beams', 1) > 1
support_diagnostics = not (is_multi_inputs or is_multi_beams) # TODO: Support diagnostics in these cases as well.
Expand Down

0 comments on commit e1fa09d

Please sign in to comment.