From 2f3d05be60c5f6414f5843c09f6e631a3213ac6c Mon Sep 17 00:00:00 2001 From: Josh Date: Wed, 10 Apr 2024 14:26:15 +0100 Subject: [PATCH 1/7] Allow invalid state regexes to exit rather than raise no allowed tokens error --- lmformatenforcer/regexparser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmformatenforcer/regexparser.py b/lmformatenforcer/regexparser.py index c782531..ed134b3 100644 --- a/lmformatenforcer/regexparser.py +++ b/lmformatenforcer/regexparser.py @@ -51,7 +51,7 @@ def add_character(self, new_character: str) -> 'RegexParser': return RegexParser(self.context, self.config, RegexParser.INVALID_STATE) def can_end(self) -> bool: - return self.current_state in self.context.pattern.finals + return self.current_state in self.context.pattern.finals or self.current_state == RegexParser.INVALID_STATE def get_allowed_characters(self) -> str: if self.current_state not in self.context.pattern.map: From 3422f7bc785332dd0878ed7a90b5b593faf8f8cb Mon Sep 17 00:00:00 2001 From: Josh C <32071009+JoshC8C7@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:00:45 +0000 Subject: [PATCH 2/7] Cache key with superfast3 (#2) * superfast mode * sadness * sadness2 * sadness3 * clear prints * debug print * delete varcheck * polish L patch --- lmformatenforcer/jsonschemaparser.py | 23 ++++++++++++++++++++++- lmformatenforcer/tokenenforcer.py | 19 ++++++++++++++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index d5d67a9..5ab904e 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -1,3 +1,4 @@ +import os from copy import deepcopy import enum import sys @@ -45,7 +46,9 @@ def __init__(self, self.context.model_class = JsonSchemaObject(**json_schema) self.context.active_parser = self self.context.alphabet_without_quotes = self.config.alphabet.replace('"', '') - + + + self.num_consecutive_whitespaces = num_consecutive_whitespaces if existing_stack is None: self.object_stack = [get_parser(self, self.context.model_class)] @@ -153,6 +156,24 @@ def shortcut_key(self) -> Optional[Hashable]: return ('json_freetext', cur_len, min_len, max_len) return None + def cache_key(self) -> Optional[Hashable]: + if self.object_stack: + current_parser = self.object_stack[-1] + if isinstance(current_parser, StringParsingState): + if not current_parser.allowed_strings and not current_parser.seen_opening_quote and not current_parser.regex_parser: + + if os.getenv("SUPERFAST_MODE", "0") in ["1", "true", "True"]: + return "superfast" + + cur_len = len(current_parser.parsed_string) + assert cur_len == 0 + min_len = current_parser.min_length or 0 + max_len = current_parser.max_length or sys.maxsize + qx = self.add_character('"').add_character('"').get_allowed_characters() + + return ("open_value", min_len, max_len, self.num_consecutive_whitespaces, qx) + + class BaseParsingState(CharacterLevelParser): def __init__(self, root: JsonSchemaParser): diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py index 6b2534e..256a564 100644 --- a/lmformatenforcer/tokenenforcer.py +++ b/lmformatenforcer/tokenenforcer.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass, field import sys from typing import Callable, Dict, Hashable, List, Optional, Tuple, Union @@ -51,7 +52,12 @@ def __init__(self, tokenizer_data: TokenEnforcerTokenizerData, parser: Character self.eos_token_id = tokenizer_data.eos_token_id self.regular_tokens = 
tokenizer_data.regular_tokens self.allowed_token_cache: Dict[Hashable, List[int]] = {} - + + if os.getenv("SUPERFAST_MODE", "0") in ["1", "true", "True"]: + for token_idx, decoded, _ in self.regular_tokens: + if decoded == '"': + self.allowed_token_cache["superfast"] = [token_idx] + config = CharacterLevelParserConfig(alphabet=tokenizer_data.tokenizer_alphabet) parser.config = config @@ -85,6 +91,8 @@ def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]: self._compute_allowed_tokens(sent_tuple, new_state) return new_state.allowed_tokens + + def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.OutputTensorState'): try: allowed_tokens: List[int] = [] @@ -123,7 +131,7 @@ def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: Token relevant_characters = tree_node.children.keys() # This next line is the heart of the traversal algorithm. We only explore paths that are shared by both the parser and the tokenizer. characters_to_explore = set(relevant_characters).intersection(allowed_characters) - + # Performance optimization: If we are in JSON freetext, all of the tokens that don't contain quote, or end with quote, are legal, so we take # their cached list. If the quote character is allowed, we only need to dynamically explore the cases where the string starts with a quote. # This breaks the elegance of the API, but otherwise it is a huge performance hit. @@ -137,7 +145,6 @@ def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: Token allowed_tokens.extend(cache.lookup_allowed_tokens(min_remaining, max_allowed_len)) characters_to_explore = characters_to_explore.intersection(['"']) - for character in characters_to_explore: next_parser = parser.add_character(character) next_tree_node = tree_node.children[character] @@ -154,6 +161,12 @@ def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', token_ prev_decoded = self.decoder(state.current_word_tokens) new_decoded = self.decoder(new_state.current_word_tokens) new_characters = new_decoded[len(prev_decoded):] + + if len(new_characters) == 1 and self.tokenizer_tree.tokens_to_strs.get(token_sequence[-2]) == '�' and self.tokenizer_tree.tokens_to_strs[new_token] == '�': + print("TRIGGERED") + decoded_unicode_char = self.decoder(token_sequence[-2:]) + new_characters = 'X'*len(decoded_unicode_char) + for character in new_characters: try: new_state.parser = new_state.parser.add_character(character) From cf2f02aed2444a2521f7efac172cba4a6176fca1 Mon Sep 17 00:00:00 2001 From: Josh Date: Thu, 28 Nov 2024 15:13:06 +0000 Subject: [PATCH 3/7] change superfast mode to be a config field, not an env var --- lmformatenforcer/characterlevelparser.py | 5 +++++ lmformatenforcer/jsonschemaparser.py | 2 +- lmformatenforcer/tokenenforcer.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lmformatenforcer/characterlevelparser.py b/lmformatenforcer/characterlevelparser.py index 186eab9..a31ef66 100644 --- a/lmformatenforcer/characterlevelparser.py +++ b/lmformatenforcer/characterlevelparser.py @@ -35,6 +35,11 @@ class CharacterLevelParserConfig: """What is the maximum json array length if not specified by the schema. Helps the LLM avoid infinite loops.""" + superfast_mode: bool = False + """ + Whether to skip calculations on acceptable tokens when starting a new string, in favour of always outputting a speech mark. 
+ """ + class CharacterLevelParser(abc.ABC): """CharacterLevelParser is an interface for classes that can parse strings one character at a time, and determine which characters are allowed at any specific time""" diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index 5ab904e..a761518 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -162,7 +162,7 @@ def cache_key(self) -> Optional[Hashable]: if isinstance(current_parser, StringParsingState): if not current_parser.allowed_strings and not current_parser.seen_opening_quote and not current_parser.regex_parser: - if os.getenv("SUPERFAST_MODE", "0") in ["1", "true", "True"]: + if self.config.superfast_mode: return "superfast" cur_len = len(current_parser.parsed_string) diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py index 256a564..c62c45c 100644 --- a/lmformatenforcer/tokenenforcer.py +++ b/lmformatenforcer/tokenenforcer.py @@ -53,7 +53,7 @@ def __init__(self, tokenizer_data: TokenEnforcerTokenizerData, parser: Character self.regular_tokens = tokenizer_data.regular_tokens self.allowed_token_cache: Dict[Hashable, List[int]] = {} - if os.getenv("SUPERFAST_MODE", "0") in ["1", "true", "True"]: + if parser.config.superfast_mode: for token_idx, decoded, _ in self.regular_tokens: if decoded == '"': self.allowed_token_cache["superfast"] = [token_idx] From e47d887c5654661e30efbd23f2b63a549050f712 Mon Sep 17 00:00:00 2001 From: Josh Date: Thu, 28 Nov 2024 15:46:12 +0000 Subject: [PATCH 4/7] print config --- lmformatenforcer/jsonschemaparser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index a761518..3bf049b 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -131,7 +131,8 @@ def get_allowed_characters(self) -> str: # continuation tokens than there are beams. 
Therefore, we allow whitespace # characters when the object stack is empty (= we are done parsing) allowed_characters = WHITESPACE_CHARACTERS - + + print("MAXCONSWS: ", self.config) if self.num_consecutive_whitespaces >= self.config.max_consecutive_whitespaces: # print("Filtering whitespace characters") allowed_characters = "".join(c for c in allowed_characters if c not in WHITESPACE_CHARACTERS) From 764e05c1590925557a89572a5008ce82038f335e Mon Sep 17 00:00:00 2001 From: Josh Date: Thu, 28 Nov 2024 15:56:56 +0000 Subject: [PATCH 5/7] print config2 --- lmformatenforcer/jsonschemaparser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index 3bf049b..0692faf 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -132,7 +132,7 @@ def get_allowed_characters(self) -> str: # characters when the object stack is empty (= we are done parsing) allowed_characters = WHITESPACE_CHARACTERS - print("MAXCONSWS: ", self.config) + print("MAXCONSWS: ", self.config.max_consecutive_whitespaces) if self.num_consecutive_whitespaces >= self.config.max_consecutive_whitespaces: # print("Filtering whitespace characters") allowed_characters = "".join(c for c in allowed_characters if c not in WHITESPACE_CHARACTERS) From b6abea19caf0583586e9cef7fb4fb764a4951d64 Mon Sep 17 00:00:00 2001 From: Josh Date: Thu, 28 Nov 2024 16:06:21 +0000 Subject: [PATCH 6/7] actually use the config passed in to the token_enforcer --- lmformatenforcer/tokenenforcer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py index c62c45c..61d3cff 100644 --- a/lmformatenforcer/tokenenforcer.py +++ b/lmformatenforcer/tokenenforcer.py @@ -58,8 +58,7 @@ def __init__(self, tokenizer_data: TokenEnforcerTokenizerData, parser: Character if decoded == '"': self.allowed_token_cache["superfast"] = [token_idx] - config = CharacterLevelParserConfig(alphabet=tokenizer_data.tokenizer_alphabet) - parser.config = config + parser.config.alphabet = tokenizer_data.tokenizer_alphabet def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]: """ From 037fc46bdda0a3e04d05e8ef57e8cd60c499fc9e Mon Sep 17 00:00:00 2001 From: Josh Date: Thu, 28 Nov 2024 16:16:05 +0000 Subject: [PATCH 7/7] remove prints --- lmformatenforcer/jsonschemaparser.py | 1 - lmformatenforcer/tokenenforcer.py | 1 - 2 files changed, 2 deletions(-) diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index 0692faf..a466964 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -132,7 +132,6 @@ def get_allowed_characters(self) -> str: # characters when the object stack is empty (= we are done parsing) allowed_characters = WHITESPACE_CHARACTERS - print("MAXCONSWS: ", self.config.max_consecutive_whitespaces) if self.num_consecutive_whitespaces >= self.config.max_consecutive_whitespaces: # print("Filtering whitespace characters") allowed_characters = "".join(c for c in allowed_characters if c not in WHITESPACE_CHARACTERS) diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py index 61d3cff..7c7beb0 100644 --- a/lmformatenforcer/tokenenforcer.py +++ b/lmformatenforcer/tokenenforcer.py @@ -162,7 +162,6 @@ def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', token_ new_characters = new_decoded[len(prev_decoded):] if len(new_characters) == 1 and 
self.tokenizer_tree.tokens_to_strs.get(token_sequence[-2]) == '�' and self.tokenizer_tree.tokens_to_strs[new_token] == '�': - print("TRIGGERED") decoded_unicode_char = self.decoder(token_sequence[-2:]) new_characters = 'X'*len(decoded_unicode_char)
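
A minimal usage sketch, assuming the series above is applied: superfast_mode is set on the
CharacterLevelParserConfig (patch 3/7), and because TokenEnforcer now updates the parser's
existing config in place instead of replacing it (patch 6/7), the flag is still visible when
the '"' token cache is built in TokenEnforcer.__init__ (patch 2/7). The constructor call shape
follows upstream lm-format-enforcer; the schema and variable names below are illustrative only.

    from lmformatenforcer.characterlevelparser import CharacterLevelParserConfig
    from lmformatenforcer.jsonschemaparser import JsonSchemaParser

    # With superfast_mode enabled, opening an unconstrained JSON string skips the
    # per-token allowed-character scan and always emits the double-quote token.
    config = CharacterLevelParserConfig(superfast_mode=True)
    parser = JsonSchemaParser(
        {"type": "object", "properties": {"name": {"type": "string"}}},
        config=config,
    )
    # The parser is then handed to TokenEnforcer(tokenizer_data, parser) as usual;
    # building tokenizer_data is unchanged by this series.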