diff --git a/lmformatenforcer/__init__.py b/lmformatenforcer/__init__.py index 596ee4c..6d5f840 100644 --- a/lmformatenforcer/__init__.py +++ b/lmformatenforcer/__init__.py @@ -1,12 +1,14 @@ __all__ = ['CharacterLevelParser', 'StringParser', 'RegexParser', - 'JsonSchemaParser', + 'JsonSchemaParser', + 'TokenEnforcer', 'generate_enforced'] from .characterlevelparser import CharacterLevelParser, StringParser from .regexparser import RegexParser from .jsonschemaparser import JsonSchemaParser +from .tokenenforcer import TokenEnforcer try: from .transformerenforcer import generate_enforced diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py index 57b6408..081f9cf 100644 --- a/lmformatenforcer/tokenenforcer.py +++ b/lmformatenforcer/tokenenforcer.py @@ -1,13 +1,7 @@ from dataclasses import dataclass, field -from typing import Dict, Hashable, List, Optional, Union +from typing import Callable, Dict, Hashable, List, Optional, Tuple import logging - -from numpy import rec from .characterlevelparser import CharacterLevelParser, ForceStopParser -from .jsonschemaparser import JsonSchemaParser -from transformers.tokenization_utils import PreTrainedTokenizerBase -from .external.jsonschemaobject import JsonSchemaObject - from .tokenizerprefixtree import TokenizerPrefixTree, TokenizerPrefixTreeNode @@ -18,24 +12,22 @@ class OutputTensorState: parser: CharacterLevelParser allowed_tokens: List[int] = field(default_factory=list) - def __init__(self, tokenizer: PreTrainedTokenizerBase, parser: CharacterLevelParser): - self.tokenizer = tokenizer - self.token_0 = tokenizer.encode("0")[-1] + def __init__(self, regular_tokens: List[Tuple[int, str]], + parser: CharacterLevelParser, + decoder: Callable[[List[int]], str], + eos_token_id: int): self.prefix_states: Dict[Hashable, TokenEnforcer.OutputTensorState] = {} self.root_parser = parser - self.tokenizer_tree = TokenizerPrefixTree(tokenizer) - - def _decode_single_token(self, token: int) -> str: - # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word. - decoded = self.tokenizer.decode([self.token_0, token])[1:] - return decoded + self.tokenizer_tree = TokenizerPrefixTree(regular_tokens) + self.decoder = decoder + self.eos_token_id = eos_token_id - def filter_allowed_tokens(self, batch_id: int, sent: 'torch.Tensor') -> List[int]: + def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]: # In order to elegantly support beam search and batching, we don't store per-batch information. # Instead, we store a hash of all the states (unique token tensors) we encountered so far. # When we encounter a new unique token tensor, we find the token tensor that led to it, and continue from there. - sent_tuple = tuple(sent.tolist()) + sent_tuple = tuple(token_sequence) prev_step_tuple = sent_tuple[:-1] if sent_tuple in self.prefix_states: @@ -44,7 +36,7 @@ def filter_allowed_tokens(self, batch_id: int, sent: 'torch.Tensor') -> List[int elif prev_step_tuple not in self.prefix_states: # We have not encountered the tensor up to the before-last entry. This means that this is the first call - the instruction / prompt tensor. # Initialize the root node - state = TokenEnforcer.OutputTensorState(str_so_far=self.tokenizer.decode(sent), + state = TokenEnforcer.OutputTensorState(str_so_far=self.decoder(token_sequence), parser=self.root_parser) self.prefix_states[sent_tuple] = state self._compute_allowed_tokens(state) @@ -52,7 +44,7 @@ def filter_allowed_tokens(self, batch_id: int, sent: 'torch.Tensor') -> List[int else: # Find the state that led to this node. We explicitly don't use the concept of "timestep" because of beam search prev_step_state = self.prefix_states[prev_step_tuple] - new_state = self._apply_new_characters(prev_step_state, sent) + new_state = self._apply_new_characters(prev_step_state, token_sequence) self.prefix_states[sent_tuple] = new_state self._compute_allowed_tokens(new_state) return new_state.allowed_tokens @@ -62,7 +54,7 @@ def _compute_allowed_tokens(self, state: 'TokenEnforcer.OutputTensorState'): shortcut_key = state.parser.shortcut_key() self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key) if state.parser.can_end(): - allowed_tokens.append(self.tokenizer.eos_token_id) + allowed_tokens.append(self.eos_token_id) if not allowed_tokens: raise ValueError(f"Parser reached state with no allowed tokens") # root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser) @@ -88,16 +80,16 @@ def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: Token next_tree_node = tree_node.children[character] self._collect_allowed_tokens(next_parser, next_tree_node, allowed_tokens, None) - - def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', sent: 'torch.Tensor'): - characters = self.tokenizer.decode(sent) + def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', token_sequence: List[int]): + characters = self.decoder(token_sequence) new_state = TokenEnforcer.OutputTensorState(str_so_far=characters, parser=state.parser) new_characters = characters[len(state.str_so_far):] for character in new_characters: if character in new_state.parser.get_allowed_characters(): new_state.parser = new_state.parser.add_character(character) else: - logging.warning(f"Received an invalid character '{character}', switching to ForceStopParser") + # This can happen in beam / batch scenarios, when some of the batches finished but others are continuing. + logging.debug(f"Received an invalid character '{character}', switching to ForceStopParser") new_state.parser = ForceStopParser() return new_state diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py index a73c413..aa075cb 100644 --- a/lmformatenforcer/tokenizerprefixtree.py +++ b/lmformatenforcer/tokenizerprefixtree.py @@ -1,21 +1,17 @@ -from typing import Dict, List -from transformers.tokenization_utils import PreTrainedTokenizerBase +from typing import Dict, List, Tuple + class TokenizerPrefixTreeNode: def __init__(self): self.tokens: List[int] = [] self.children: Dict[str, TokenizerPrefixTreeNode] = {} + class TokenizerPrefixTree: - def __init__(self, tokenizer: PreTrainedTokenizerBase): - self.tokenizer = tokenizer - self.token_0 = tokenizer.encode("0")[-1] + def __init__(self, regular_tokens: List[Tuple[int, str]]): self.root = TokenizerPrefixTreeNode() self.json_freetext_tokens: List[int] = [] - for token_idx in range(self.tokenizer.vocab_size): - if token_idx in self.tokenizer.all_special_ids: - continue - decoded = self._decode_single_token(token_idx) + for token_idx, decoded in regular_tokens: self._add_token_to_tree(decoded, token_idx, self.root) # Performance optimization - cache the tokens of all the strings that don't contain a quote in the middle. # When we are in a JSON freetext string field, they will all be permitted and this will save a lot of tree iterations. @@ -28,8 +24,3 @@ def _add_token_to_tree(self, token_str: str, token_idx: int, node: TokenizerPref node.children[character] = TokenizerPrefixTreeNode() node = node.children[character] node.tokens.append(token_idx) - - def _decode_single_token(self, token: int) -> str: - # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word. - decoded = self.tokenizer.decode([self.token_0, token])[1:] - return decoded \ No newline at end of file diff --git a/lmformatenforcer/transformerenforcer.py b/lmformatenforcer/transformerenforcer.py index 602bab9..67b5630 100644 --- a/lmformatenforcer/transformerenforcer.py +++ b/lmformatenforcer/transformerenforcer.py @@ -58,16 +58,34 @@ def get_leading_scores(self) -> Tuple[List[int], List[float]]: return best_tokens.tolist(), token_probs_list +def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str]]: + token_0 = tokenizer.encode("0")[-1] + regular_tokens = [] + for token_idx in range(tokenizer.vocab_size): + if token_idx in tokenizer.all_special_ids: + continue + # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word. + decoded = tokenizer.decode([token_0, token_idx])[1:] + regular_tokens.append((token_idx, decoded)) + return regular_tokens + + def generate_enforced(model: AutoModelForCausalLM, tokenizer: PreTrainedTokenizerBase, character_level_parser: CharacterLevelParser, **kwargs: dict) -> Union[str, dict]: - token_enforcer = TokenEnforcer(tokenizer, character_level_parser) + + regular_tokens = _build_regular_tokens_list(tokenizer) + token_enforcer = TokenEnforcer(regular_tokens, character_level_parser, tokenizer.decode, tokenizer.eos_token_id) + def transformers_filter_allowed_tokens(batch_id: int, sent: torch.Tensor) -> List[int]: + token_sequence = sent.tolist() + return token_enforcer.get_allowed_tokens(token_sequence) + is_multi_inputs = kwargs['input_ids'].shape[0] > 1 is_multi_beams = kwargs.get('num_beams', 1) > 1 logits_saver = LogitsSaverManager(model) - logits_saver.replace_logits_warper(token_enforcer.filter_allowed_tokens) + logits_saver.replace_logits_warper(transformers_filter_allowed_tokens) generate_kwargs = kwargs return_dict_in_generate = kwargs.get('return_dict_in_generate', False) output_scores = kwargs.get('output_scores', None) diff --git a/samples/colab_llama2_enforcer.ipynb b/samples/colab_llama2_enforcer.ipynb index 1687ec7..16171fd 100644 --- a/samples/colab_llama2_enforcer.ipynb +++ b/samples/colab_llama2_enforcer.ipynb @@ -41,7 +41,7 @@ "# import sys\n", "# import os\n", "# sys.path.append(os.path.abspath('..'))\n", - "# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'" + "## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'" ] }, { @@ -55,16 +55,9 @@ "text": [ "/home/noamgat/mambaforge/envs/commentranker/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "Loading checkpoint shards: 100%|██████████| 2/2 [03:44<00:00, 112.07s/it]\n", + "Loading checkpoint shards: 100%|██████████| 2/2 [05:29<00:00, 164.58s/it]\n", "Using pad_token, but it is not set yet.\n" ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "OK\n" - ] } ], "source": [ @@ -136,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -226,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -297,22 +290,14 @@ "data": { "text/markdown": [ "```\n", - " Of course, I'd be happy to provide information about Michael Jordan using the provided JSON schema. Here's the response:\n", + " Of course! I'd be happy to provide information about Michael Jordan using the provided JSON schema.\n", "{\n", "\"title\": \"AnswerFormat\",\n", "\"type\": \"object\",\n", "\"properties\": {\n", - "\"first_name\": {\n", - "\"title\": \"First Name\",\n", - "\"type\": \"string\",\n", - "\"required\": true\n", - "\n", - "},\n", - "\"last_name\": {\n", - "\n", - "\"title\": \"Last Name\",\n", - "\n", - "\"type\":\n", + "\"first_name\": {\"title\": \"First Name\", \"type\": \"string\"},\n", + "\"last_name\": {\"title\": \"Last Name\", \"type\": \"string\"},\n", + "\"year_of_birth\": {\"title\": \"Year Of Birth\", \"\n", "```" ], "text/plain": [ @@ -372,7 +357,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -961,7 +946,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -1006,7 +991,7 @@ "data": { "text/markdown": [ "```\n", - " Thank you for your question! Michael Jordan was born in the year 1963.\n", + " Thank you for asking! Michael Jordan was born in the year 1963.\n", "```" ], "text/plain": [ @@ -1484,7 +1469,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ {