Feature/faster json freetext #76

Merged
merged 5 commits into from
Feb 19, 2024
1 change: 1 addition & 0 deletions lmformatenforcer/tokenenforcer.py
@@ -109,6 +109,7 @@ def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.Out
raise
except Exception:
# Other exceptions are potential bugs and should be reported
logging.basicConfig(level=logging.ERROR) # Initialize if no loggers
prefix = self.decoder(list(state_tokens))
logging.exception(f"Unknown LMFormatEnforcer Problem. Prefix: '{prefix}'\n"
"Terminating the parser. Please open an issue at \n"
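The one-line addition above is a logging fallback: logging.basicConfig only attaches a handler when the root logger has none, so the logging.exception call that follows it stays visible even if the host application never configured logging, while existing configurations are left untouched. A minimal standalone sketch of the pattern (not part of the diff; the message text is shortened):

import logging

def report_unknown_problem(prefix: str) -> None:
    # basicConfig is a no-op if the root logger already has handlers,
    # so this never overrides an application's own logging setup.
    logging.basicConfig(level=logging.ERROR)
    # Called from inside an except block, logging.exception also records the traceback.
    logging.exception(f"Unknown LMFormatEnforcer Problem. Prefix: '{prefix}'")

try:
    raise RuntimeError("simulated failure")  # stand-in for the real parsing error
except Exception:
    report_unknown_problem("<decoded prefix>")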
95 changes: 57 additions & 38 deletions lmformatenforcer/tokenizerprefixtree.py
@@ -16,11 +16,42 @@ class JsonFreetextTokenCache:
a separate allowlist for all possible constraint states up to the maximum token length (16 in Llama, for example).
After deduplication, this results in about 75 lists for the Llama tokenizer.
"""
class _StringLengthTokenCache:
"""This is an internal data structure, that given a list of string+token pairs,
can quickly return all token ids of strings between certain lengths"""
def __init__(self):
self.tokens: List[int] = []
self.first_index_geq_than_length: List[int] = [0]
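# first_index_geq_than_length[L] is the index of the first entry in the
# length-sorted `tokens` list whose string has length >= L, so any length
# range maps to a single contiguous slice of `tokens`.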

def build(self, token_strs_to_idx: List[Tuple[str, int]]):
# TODO: If this becomes a performance bottleneck, bucket sort instead.
token_strs_to_idx = sorted(token_strs_to_idx, key=lambda p: len(p[0]))
self.tokens = [pair[1] for pair in token_strs_to_idx]
# self.token_strs = [pair[0] for pair in token_strs_to_idx] # For debugging
token_lengths = [len(pair[0]) for pair in token_strs_to_idx]
for idx, token_length in enumerate(token_lengths):
while len(self.first_index_geq_than_length) <= token_length:
self.first_index_geq_than_length.append(idx)
self.first_index_geq_than_length.append(len(token_lengths))
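# Trailing entry: a minimum length just past the longest token maps to the end
# of `tokens`, i.e. an empty slice.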

def get_indices_between_length(self, min_length=-1, max_length=-1) -> List[int]:
if min_length >= len(self.first_index_geq_than_length):
return []
start_index = self.first_index_geq_than_length[min_length] if min_length > 0 else 0
if max_length == 0:
end_index = 0
elif max_length + 1 < len(self.first_index_geq_than_length):
end_index = self.first_index_geq_than_length[max_length + 1]
else:
end_index = len(self.tokens)
return self.tokens[start_index:end_index]

def __init__(self) -> None:
self.token_str_to_num: Dict[str, int] = {}
self.token_num_to_str: Dict[int, str] = {}
self.allowlist_cache: Dict[Tuple[int, int], Tuple[int, ...]] = {}
self.max_token_len = 0
self.max_allowed_token_len = 32
self.regular_tokens_length_cache = JsonFreetextTokenCache._StringLengthTokenCache()
self.quote_tokens_length_cache = JsonFreetextTokenCache._StringLengthTokenCache()

def add_token(self, token_str: str, token_int: int):
assert not self.allowlist_cache, "Cannot add more tokens after allowlists were precalculated"
@@ -39,56 +70,44 @@ def add_token(self, token_str: str, token_int: int):
# TODO: Should we instead ALWAYS allow them?
return

self.token_str_to_num[token_str] = token_int
self.max_token_len = min(max(self.max_token_len, len(token_str)), self.max_allowed_token_len)
self.token_num_to_str[token_int] = token_str

def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, ...]:
"""
Get the list of tokens that are allowed within a JSON string, such that:
1. all candidate tokens are at most `max_len` characters long (excluding the trailing quote), and
2. if a token ends with a quote, it's at least `min_remaining` chars long (excluding the quote).
"""
return self.allowlist_cache[(min_remaining, max_len)]
cache_key = (min_remaining, max_len)
if cache_key not in self.allowlist_cache:
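# Tokens ending with '"' close the JSON string, so their content (excluding the
# closing quote) must be between min_remaining and max_len; the +1 offsets account
# for the quote character that is part of the stored token string.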
tokens_with_quote = self.quote_tokens_length_cache.get_indices_between_length(min_remaining + 1, max_len + 1)
tokens_without_quote = self.regular_tokens_length_cache.get_indices_between_length(-1, max_len)
combined = tokens_with_quote + tokens_without_quote
self.allowlist_cache[cache_key] = tuple(combined)
return self.allowlist_cache[cache_key]

def freeze(self) -> None:
"""
Precalculate token allowlists for all valid combinations of `min_remaining` and `max_len`
based on the tokens that were added with `add_token()`.
"""
all_tokens: List[str] = sorted(self.token_str_to_num.keys())
all_tokens: List[Tuple[str, int]] = list((s, n) for n, s in self.token_num_to_str.items())
assert all_tokens, "Cannot precalculate allowlists for an empty token list"
assert not any(t == '' for t in all_tokens), "Tokenizer must not contain empty tokens"

def _valid_for_min_remaining(token, min_remaining):
return not token.endswith('"') or len(token.rstrip('"')) >= min_remaining

def _valid_for_max_len(token, max_len):
return len(token.rstrip('"')) <= max_len

# Make a 2D array of constrained allowlists, indexed by tuple `(min_remaining, max_len)`
token_lists = {}
for min_remaining in range(self.max_token_len + 1):
for max_len in range(self.max_token_len + 1):
if max_len >= min_remaining: # Skip combinations that are never used
token_lists[(min_remaining, max_len)] = tuple(sorted([
token for token in all_tokens
if _valid_for_min_remaining(token, min_remaining) and _valid_for_max_len(token, max_len)
]))

# Deduplicate the lists to save RAM as many of them will be identical
unique_lists = set(token_lists.values())
for key, lst in token_lists.items():
for uniq in unique_lists:
if len(uniq) == len(lst) and uniq == lst:
token_lists[key] = uniq
break

# Turn token strings into token numbers
self.allowlist_cache = {
key: tuple(self.token_str_to_num[t] for t in lst)
for key, lst in token_lists.items()
}
del self.token_str_to_num
assert not any(pair[0] == '' for pair in all_tokens), "Tokenizer must not contain empty tokens"

regular_tokens: List[Tuple[str, int]] = []
quote_tokens: List[Tuple[str, int]] = []
for pair in all_tokens:
if pair[0].endswith('"'):
quote_tokens.append(pair)
else:
regular_tokens.append(pair)

self.regular_tokens_length_cache.build(regular_tokens)
self.quote_tokens_length_cache.build(quote_tokens)
self.max_token_len = max(len(self.regular_tokens_length_cache.first_index_geq_than_length),
len(self.quote_tokens_length_cache.first_index_geq_than_length))
del self.token_num_to_str


class TokenizerPrefixTree:
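Taken together, the rewritten freeze() and lookup_allowed_tokens() replace the precomputed 2D table of allowlists with two length-sorted caches (quote-ending and regular tokens) that are sliced and memoized on demand. A minimal usage sketch, mirroring what the new test below does; the toy vocabulary and token ids are made up for illustration:

from lmformatenforcer.tokenizerprefixtree import JsonFreetextTokenCache

cache = JsonFreetextTokenCache()
toy_vocab = ['a', 'ab', 'abc', 'ab"', '"']  # made-up token strings
for token_id, token_str in enumerate(toy_vocab):
    cache.add_token(token_str, token_id)
cache.freeze()

# At most 2 more characters may be emitted, and a closing quote is only legal
# after at least 1 more character: expect 'a', 'ab' and 'ab"'.
allowed = cache.lookup_allowed_tokens(min_remaining=1, max_len=2)
print(sorted(allowed))  # [0, 1, 3] with the toy vocabulary above

The first lookup for a given (min_remaining, max_len) pair slices the two length caches and stores the result in allowlist_cache, so repeated constraint states cost a single dictionary hit.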
41 changes: 41 additions & 0 deletions tests/test_tokenizercaching.py
@@ -0,0 +1,41 @@

from lmformatenforcer.tokenizerprefixtree import JsonFreetextTokenCache

def test_json_freetext_cache():
token_to_str = {}
cache = JsonFreetextTokenCache()
test_length = 500
letters = "abcde"
num_letters = len(letters)
def _register_token(token_str: str):
token_idx = len(token_to_str)
token_to_str[token_idx] = token_str
cache.add_token(token_str, token_idx)
_register_token("\"")
for i in range(1, test_length):
for letter in letters:
_register_token(letter * i)
_register_token(letter * i + '"')
cache.freeze()

def _assert_allowed_tokens(_min_remaining, _max_length, _num_expected_tokens):
allowed_tokens = cache.lookup_allowed_tokens(_min_remaining, _max_length)
if len(allowed_tokens) != _num_expected_tokens:
allowed_token_strs = "|".join(token_to_str[token_idx] for token_idx in allowed_tokens)
raise Exception(f"Min={_min_remaining}, Max={_max_length}, Expected {_num_expected_tokens}, got {len(allowed_tokens)} : {allowed_token_strs}")

for min_remaining in range(0, test_length):
for max_length in range(min_remaining, test_length):

num_expected_quote_tokens = num_letters * (max_length - min_remaining + 1)
if min_remaining == 0:
# at 0, there is only one quoted string (")
num_expected_quote_tokens -= (num_letters - 1)

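# Non-quote tokens are letter * i for 1 <= i <= max_length, one token per letter.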
num_expected_regular_tokens = max_length * num_letters
num_expected_tokens = num_expected_quote_tokens + num_expected_regular_tokens
_assert_allowed_tokens(min_remaining, max_length, num_expected_tokens)
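# Boundary cases: with no effective length cap every registered token is allowed;
# with min_remaining above the longest content length, only tokens that do not
# close the string remain.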

_assert_allowed_tokens(0, test_length + 1, len(token_to_str))
num_nonquote_tokens = (test_length - 1) * num_letters
_assert_allowed_tokens(test_length + 1, test_length + 1, num_nonquote_tokens)