Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use config3 #150

Closed
wants to merge 10 commits into from
5 changes: 5 additions & 0 deletions lmformatenforcer/characterlevelparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ class CharacterLevelParserConfig:
"""What is the maximum json array length if not specified by the schema. Helps the LLM
avoid infinite loops."""

superfast_mode: bool = False
"""
Whether to skip calculations on acceptable tokens when starting a new string, in favour of always outputting a speech mark.
"""


class CharacterLevelParser(abc.ABC):
"""CharacterLevelParser is an interface for classes that can parse strings one character at a time, and determine which characters are allowed at any specific time"""
Expand Down
25 changes: 23 additions & 2 deletions lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from copy import deepcopy
import enum
import sys
Expand Down Expand Up @@ -45,7 +46,9 @@ def __init__(self,
self.context.model_class = JsonSchemaObject(**json_schema)
self.context.active_parser = self
self.context.alphabet_without_quotes = self.config.alphabet.replace('"', '')




self.num_consecutive_whitespaces = num_consecutive_whitespaces
if existing_stack is None:
self.object_stack = [get_parser(self, self.context.model_class)]
Expand Down Expand Up @@ -128,7 +131,7 @@ def get_allowed_characters(self) -> str:
# continuation tokens than there are beams. Therefore, we allow whitespace
# characters when the object stack is empty (= we are done parsing)
allowed_characters = WHITESPACE_CHARACTERS

if self.num_consecutive_whitespaces >= self.config.max_consecutive_whitespaces:
# print("Filtering whitespace characters")
allowed_characters = "".join(c for c in allowed_characters if c not in WHITESPACE_CHARACTERS)
Expand All @@ -153,6 +156,24 @@ def shortcut_key(self) -> Optional[Hashable]:
return ('json_freetext', cur_len, min_len, max_len)
return None

def cache_key(self) -> Optional[Hashable]:
    """Return a hashable key identifying the current parsing state for the
    allowed-token cache, or None (implicitly) when the state is not cacheable.

    Only one family of states is cached: the parser is positioned at the very
    start of an unconstrained JSON string value (no enum restriction, no regex
    constraint, opening quote not yet consumed).
    """
    if self.object_stack:
        current_parser = self.object_stack[-1]
        if isinstance(current_parser, StringParsingState):
            # Cache only plain free-text strings: no fixed set of allowed
            # strings, the opening quote has not been seen yet, and no
            # regex/pattern constrains the content.
            if not current_parser.allowed_strings and not current_parser.seen_opening_quote and not current_parser.regex_parser:

                if self.config.superfast_mode:
                    # Superfast mode collapses all such states into a single
                    # shared key; the token enforcer appears to pre-seed this
                    # key with the quote-only token list — confirm against
                    # TokenEnforcer.__init__.
                    return "superfast"

                cur_len = len(current_parser.parsed_string)
                # Opening quote not yet seen, so nothing of the string can
                # have been parsed.
                assert cur_len == 0
                min_len = current_parser.min_length or 0
                max_len = current_parser.max_length or sys.maxsize
                # qx captures which characters are legal immediately after a
                # completed empty string value ('""'), distinguishing states
                # that differ only in what may follow the string.
                qx = self.add_character('"').add_character('"').get_allowed_characters()

                return ("open_value", min_len, max_len, self.num_consecutive_whitespaces, qx)



class BaseParsingState(CharacterLevelParser):
def __init__(self, root: JsonSchemaParser):
Expand Down
21 changes: 16 additions & 5 deletions lmformatenforcer/tokenenforcer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from dataclasses import dataclass, field
import sys
from typing import Callable, Dict, Hashable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -51,9 +52,13 @@ def __init__(self, tokenizer_data: TokenEnforcerTokenizerData, parser: Character
self.eos_token_id = tokenizer_data.eos_token_id
self.regular_tokens = tokenizer_data.regular_tokens
self.allowed_token_cache: Dict[Hashable, List[int]] = {}

config = CharacterLevelParserConfig(alphabet=tokenizer_data.tokenizer_alphabet)
parser.config = config

if parser.config.superfast_mode:
for token_idx, decoded, _ in self.regular_tokens:
if decoded == '"':
self.allowed_token_cache["superfast"] = [token_idx]

parser.config.alphabet = tokenizer_data.tokenizer_alphabet

def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]:
"""
Expand Down Expand Up @@ -85,6 +90,8 @@ def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]:
self._compute_allowed_tokens(sent_tuple, new_state)
return new_state.allowed_tokens



def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.OutputTensorState'):
try:
allowed_tokens: List[int] = []
Expand Down Expand Up @@ -123,7 +130,7 @@ def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: Token
relevant_characters = tree_node.children.keys()
# This next line is the heart of the traversal algorithm. We only explore paths that are shared by both the parser and the tokenizer.
characters_to_explore = set(relevant_characters).intersection(allowed_characters)

# Performance optimization: If we are in JSON freetext, all of the tokens that don't contain quote, or end with quote, are legal, so we take
# their cached list. If the quote character is allowed, we only need to dynamically explore the cases where the string starts with a quote.
# This breaks the elegance of the API, but otherwise it is a huge performance hit.
Expand All @@ -137,7 +144,6 @@ def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: Token

allowed_tokens.extend(cache.lookup_allowed_tokens(min_remaining, max_allowed_len))
characters_to_explore = characters_to_explore.intersection(['"'])

for character in characters_to_explore:
next_parser = parser.add_character(character)
next_tree_node = tree_node.children[character]
Expand All @@ -154,6 +160,11 @@ def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', token_
prev_decoded = self.decoder(state.current_word_tokens)
new_decoded = self.decoder(new_state.current_word_tokens)
new_characters = new_decoded[len(prev_decoded):]

if len(new_characters) == 1 and self.tokenizer_tree.tokens_to_strs.get(token_sequence[-2]) == '�' and self.tokenizer_tree.tokens_to_strs[new_token] == '�':
decoded_unicode_char = self.decoder(token_sequence[-2:])
new_characters = 'X'*len(decoded_unicode_char)

for character in new_characters:
try:
new_state.parser = new_state.parser.add_character(character)
Expand Down
Loading