From d239e263e285d4d80a90cb63d8c502f5f9dc3bf5 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Sat, 17 Feb 2024 23:02:32 +0200 Subject: [PATCH 1/5] Added unit test for json freetext token filtering --- tests/test_tokenizercaching.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 tests/test_tokenizercaching.py diff --git a/tests/test_tokenizercaching.py b/tests/test_tokenizercaching.py new file mode 100644 index 0000000..ac9e6f6 --- /dev/null +++ b/tests/test_tokenizercaching.py @@ -0,0 +1,22 @@ + +from lmformatenforcer.tokenizerprefixtree import JsonFreetextTokenCache + + +def test_json_freetext_cache(): + token_to_str = {} + cache = JsonFreetextTokenCache() + test_length = 20 + cache.max_allowed_token_len = test_length + def _register_token(token_idx: int, token_str: str): + token_to_str[token_idx] = token_str + cache.add_token(token_str, token_idx) + for i in range(1, test_length): + _register_token(i, "a" * i) + _register_token(i + cache.max_allowed_token_len, "a" * i + '"') + cache.freeze() + for min_remaining in range(1, test_length): + for max_length in range(min_remaining, test_length): + allowed_tokens = cache.lookup_allowed_tokens(min_remaining, max_length) + num_expected_quote_tokens = max_length - min_remaining + 1 + num_expected_regular_tokens = max_length + assert len(allowed_tokens) == num_expected_quote_tokens + num_expected_regular_tokens From 4917d643f130ce6ea9d32547179fe94833cdabaa Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Sat, 17 Feb 2024 23:41:48 +0200 Subject: [PATCH 2/5] WIP: Faster algo --- lmformatenforcer/tokenizerprefixtree.py | 83 +++++++++++++++---------- tests/test_tokenizercaching.py | 5 +- 2 files changed, 51 insertions(+), 37 deletions(-) diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py index 3bfd82c..03b7f2e 100644 --- a/lmformatenforcer/tokenizerprefixtree.py +++ b/lmformatenforcer/tokenizerprefixtree.py @@ -16,11 +16,38 @@ class JsonFreetextTokenCache: a separate allowlist for all possible constraint states up to maximum token length (16 in Llama, for example). After deduplication, this results in about ~75 lists for the Llama tokenizer. """ + class _StringLengthTokenCache: + def __init__(self): + self.tokens: List[int] = [] + self.first_index_geq_than_length: List[int] = [0] + + def build(self, token_strs_to_idx: List[Tuple[str, int]]): + token_strs_to_idx = sorted(token_strs_to_idx, key=lambda p:len(p[0])) + self.tokens = [pair[1] for pair in token_strs_to_idx] + self.token_strs = [pair[0] for pair in token_strs_to_idx] # For debugging + token_lengths = [len(pair[0]) for pair in token_strs_to_idx] + max_length = token_lengths[-1] + for idx, token_length in enumerate(token_lengths): + while len(self.first_index_geq_than_length) < token_length: + self.first_index_geq_than_length.append(idx) + self.first_index_geq_than_length.append(max_length) + + def get_indices_between_length(self, min_length=-1, max_length=-1) -> List[int]: + if min_length > len(self.first_index_geq_than_length): + return [] + start_index = self.first_index_geq_than_length[min_length] if min_length > 0 else 0 + if max_length > 0 and max_length + 1 < len(self.first_index_geq_than_length): + end_index = max_length + 1 + else: + end_index = -1 + return self.tokens[start_index:end_index] + def __init__(self, ) -> None: self.token_str_to_num: Dict[str, int] = {} self.allowlist_cache: Dict[Tuple[int, int], Tuple[int, ...]] = {} self.max_token_len = 0 - self.max_allowed_token_len = 32 + self.regular_tokens_length_cache = JsonFreetextTokenCache._StringLengthTokenCache() + self.quote_tokens_length_cache = JsonFreetextTokenCache._StringLengthTokenCache() def add_token(self, token_str: str, token_int: int): assert not self.allowlist_cache, "Cannot add more tokens after allowlists were precalculated" @@ -40,7 +67,6 @@ def add_token(self, token_str: str, token_int: int): return self.token_str_to_num[token_str] = token_int - self.max_token_len = min(max(self.max_token_len, len(token_str)), self.max_allowed_token_len) def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, ...]: """ @@ -48,6 +74,12 @@ def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, 1. all candidate tokens are at most `max_len` characters long (excluding the trailing quote), and 2. if a token ends with a quote, it's at least `min_remaining` chars long (excluding the quote). """ + cache_key = (min_remaining, max_len) + if cache_key not in self.allowlist_cache: + tokens_with_quote = self.quote_tokens_length_cache.get_indices_between_length(min_remaining + 1, max_len + 1) + tokens_without_quote = self.regular_tokens_length_cache.get_indices_between_length(-1, max_len) + combined = tokens_with_quote + tokens_without_quote + self.allowlist_cache[cache_key] = tuple(combined) return self.allowlist_cache[(min_remaining, max_len)] def freeze(self) -> None: @@ -55,39 +87,22 @@ def freeze(self) -> None: Precalculate token allowlists for all valid combinations of `min_remaining` and `max_len` based on the tokens that were added with `add_token()`. """ - all_tokens: List[str] = sorted(self.token_str_to_num.keys()) + all_tokens: List[Tuple[str, int]] = list(self.token_str_to_num.items()) assert all_tokens, "Cannot precalculate allowlists for an empty token list" - assert not any(t == '' for t in all_tokens), "Tokenizer must not contain empty tokens" - - def _valid_for_min_remaining(token, min_remaining): - return not token.endswith('"') or len(token.rstrip('"')) >= min_remaining - - def _valid_for_max_len(token, max_len): - return len(token.rstrip('"')) <= max_len - - # Make a 2D array of constrained allowlists, indexed by tuple `(min_remaining, max_len)` - token_lists = {} - for min_remaining in range(self.max_token_len + 1): - for max_len in range(self.max_token_len + 1): - if max_len >= min_remaining: # Skip combinations that are never used - token_lists[(min_remaining, max_len)] = tuple(sorted([ - token for token in all_tokens - if _valid_for_min_remaining(token, min_remaining) and _valid_for_max_len(token, max_len) - ])) - - # Deduplicate the lists to save RAM as many of them will be identical - unique_lists = set(token_lists.values()) - for key, lst in token_lists.items(): - for uniq in unique_lists: - if len(uniq) == len(lst) and uniq == lst: - token_lists[key] = uniq - break - - # Turn token strings into token numbers - self.allowlist_cache = { - key: tuple(self.token_str_to_num[t] for t in lst) - for key, lst in token_lists.items() - } + assert not any(pair[0] == '' for pair in all_tokens), "Tokenizer must not contain empty tokens" + + regular_tokens: List[Tuple[str, int]] = [] + quote_tokens: List[Tuple[str, int]] = [] + for pair in all_tokens: + if pair[0].endswith('"'): + quote_tokens.append(pair) + else: + regular_tokens.append(pair) + + self.regular_tokens_length_cache.build(regular_tokens) + self.quote_tokens_length_cache.build(quote_tokens) + self.max_token_len = max(len(self.regular_tokens_length_cache.first_index_geq_than_length), + len(self.quote_tokens_length_cache.first_index_geq_than_length)) del self.token_str_to_num diff --git a/tests/test_tokenizercaching.py b/tests/test_tokenizercaching.py index ac9e6f6..3ae9d3e 100644 --- a/tests/test_tokenizercaching.py +++ b/tests/test_tokenizercaching.py @@ -5,14 +5,13 @@ def test_json_freetext_cache(): token_to_str = {} cache = JsonFreetextTokenCache() - test_length = 20 - cache.max_allowed_token_len = test_length + test_length = 200 def _register_token(token_idx: int, token_str: str): token_to_str[token_idx] = token_str cache.add_token(token_str, token_idx) for i in range(1, test_length): _register_token(i, "a" * i) - _register_token(i + cache.max_allowed_token_len, "a" * i + '"') + _register_token(i + test_length, "a" * i + '"') cache.freeze() for min_remaining in range(1, test_length): for max_length in range(min_remaining, test_length): From b2f451bd11e09af6f8d871f9e09bfe9d0ad3661d Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Sun, 18 Feb 2024 20:55:17 +0200 Subject: [PATCH 3/5] Algo works almost 100% --- lmformatenforcer/tokenizerprefixtree.py | 13 ++++++----- tests/test_tokenizercaching.py | 29 +++++++++++++++++-------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py index 03b7f2e..8856cf4 100644 --- a/lmformatenforcer/tokenizerprefixtree.py +++ b/lmformatenforcer/tokenizerprefixtree.py @@ -24,20 +24,21 @@ def __init__(self): def build(self, token_strs_to_idx: List[Tuple[str, int]]): token_strs_to_idx = sorted(token_strs_to_idx, key=lambda p:len(p[0])) self.tokens = [pair[1] for pair in token_strs_to_idx] - self.token_strs = [pair[0] for pair in token_strs_to_idx] # For debugging + # self.token_strs = [pair[0] for pair in token_strs_to_idx] # For debugging token_lengths = [len(pair[0]) for pair in token_strs_to_idx] - max_length = token_lengths[-1] for idx, token_length in enumerate(token_lengths): - while len(self.first_index_geq_than_length) < token_length: + while len(self.first_index_geq_than_length) <= token_length: self.first_index_geq_than_length.append(idx) - self.first_index_geq_than_length.append(max_length) + self.first_index_geq_than_length.append(len(token_lengths)) def get_indices_between_length(self, min_length=-1, max_length=-1) -> List[int]: if min_length > len(self.first_index_geq_than_length): return [] start_index = self.first_index_geq_than_length[min_length] if min_length > 0 else 0 - if max_length > 0 and max_length + 1 < len(self.first_index_geq_than_length): - end_index = max_length + 1 + if max_length == 0: + end_index = 0 + elif max_length + 1 < len(self.first_index_geq_than_length): + end_index = self.first_index_geq_than_length[max_length + 1] else: end_index = -1 return self.tokens[start_index:end_index] diff --git a/tests/test_tokenizercaching.py b/tests/test_tokenizercaching.py index 3ae9d3e..b00e475 100644 --- a/tests/test_tokenizercaching.py +++ b/tests/test_tokenizercaching.py @@ -1,21 +1,32 @@ from lmformatenforcer.tokenizerprefixtree import JsonFreetextTokenCache - def test_json_freetext_cache(): token_to_str = {} cache = JsonFreetextTokenCache() - test_length = 200 - def _register_token(token_idx: int, token_str: str): + test_length = 500 + letters = "abcdefg" + num_letters = len(letters) + def _register_token(token_str: str): + token_idx = len(token_to_str) token_to_str[token_idx] = token_str cache.add_token(token_str, token_idx) + _register_token("\"") for i in range(1, test_length): - _register_token(i, "a" * i) - _register_token(i + test_length, "a" * i + '"') + for letter in letters: + _register_token(letter * i) + _register_token(letter * i + '"') cache.freeze() - for min_remaining in range(1, test_length): + for min_remaining in range(0, test_length): for max_length in range(min_remaining, test_length): allowed_tokens = cache.lookup_allowed_tokens(min_remaining, max_length) - num_expected_quote_tokens = max_length - min_remaining + 1 - num_expected_regular_tokens = max_length - assert len(allowed_tokens) == num_expected_quote_tokens + num_expected_regular_tokens + num_expected_quote_tokens = num_letters * (max_length - min_remaining + 1) + if min_remaining == 0: + # at 0, there is only one quoted string (") + num_expected_quote_tokens -= (num_letters - 1) + + num_expected_regular_tokens = max_length * num_letters + num_expected_tokens = num_expected_quote_tokens + num_expected_regular_tokens + if len(allowed_tokens) != num_expected_tokens: + allowed_token_strs = "|".join(token_to_str[token_idx] for token_idx in allowed_tokens) + raise Exception(f"Min={min_remaining}, Max={max_length}, Expected {num_expected_tokens}, got {len(allowed_tokens)} : {allowed_token_strs}") From b13ee1c17f3f0e4ff54460b6032c13428e93a801 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Sun, 18 Feb 2024 21:04:00 +0200 Subject: [PATCH 4/5] json freetext tokenizer cache allows multiple tokens of same str --- lmformatenforcer/tokenizerprefixtree.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py index 8856cf4..fb39d18 100644 --- a/lmformatenforcer/tokenizerprefixtree.py +++ b/lmformatenforcer/tokenizerprefixtree.py @@ -44,7 +44,7 @@ def get_indices_between_length(self, min_length=-1, max_length=-1) -> List[int]: return self.tokens[start_index:end_index] def __init__(self, ) -> None: - self.token_str_to_num: Dict[str, int] = {} + self.token_num_to_str: Dict[int, str] = {} self.allowlist_cache: Dict[Tuple[int, int], Tuple[int, ...]] = {} self.max_token_len = 0 self.regular_tokens_length_cache = JsonFreetextTokenCache._StringLengthTokenCache() @@ -67,7 +67,7 @@ def add_token(self, token_str: str, token_int: int): # TODO: Should we instead ALWAYS allow them? return - self.token_str_to_num[token_str] = token_int + self.token_num_to_str[token_int] = token_str def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, ...]: """ @@ -81,14 +81,14 @@ def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, tokens_without_quote = self.regular_tokens_length_cache.get_indices_between_length(-1, max_len) combined = tokens_with_quote + tokens_without_quote self.allowlist_cache[cache_key] = tuple(combined) - return self.allowlist_cache[(min_remaining, max_len)] + return self.allowlist_cache[cache_key] def freeze(self) -> None: """ Precalculate token allowlists for all valid combinations of `min_remaining` and `max_len` based on the tokens that were added with `add_token()`. """ - all_tokens: List[Tuple[str, int]] = list(self.token_str_to_num.items()) + all_tokens: List[Tuple[str, int]] = list((s, n) for n,s in self.token_num_to_str.items()) assert all_tokens, "Cannot precalculate allowlists for an empty token list" assert not any(pair[0] == '' for pair in all_tokens), "Tokenizer must not contain empty tokens" @@ -104,7 +104,7 @@ def freeze(self) -> None: self.quote_tokens_length_cache.build(quote_tokens) self.max_token_len = max(len(self.regular_tokens_length_cache.first_index_geq_than_length), len(self.quote_tokens_length_cache.first_index_geq_than_length)) - del self.token_str_to_num + del self.token_num_to_str class TokenizerPrefixTree: From 2776c230066036a2cfc993083593acfa30b31581 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Mon, 19 Feb 2024 06:10:54 +0200 Subject: [PATCH 5/5] Fixing off by one errors, all tests pass --- lmformatenforcer/tokenenforcer.py | 1 + lmformatenforcer/tokenizerprefixtree.py | 7 +++++-- tests/test_tokenizercaching.py | 19 ++++++++++++++----- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py index a990153..668ac7d 100644 --- a/lmformatenforcer/tokenenforcer.py +++ b/lmformatenforcer/tokenenforcer.py @@ -109,6 +109,7 @@ def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.Out raise except Exception: # Other exceptions are potential bugs and should be reported + logging.basicConfig(level=logging.ERROR) # Initialize if no loggers prefix = self.decoder(list(state_tokens)) logging.exception(f"Unknown LMFormatEnforcer Problem. Prefix: '{prefix}'\n" "Terminating the parser. Please open an issue at \n" diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py index fb39d18..aa04f54 100644 --- a/lmformatenforcer/tokenizerprefixtree.py +++ b/lmformatenforcer/tokenizerprefixtree.py @@ -17,11 +17,14 @@ class JsonFreetextTokenCache: After deduplication, this results in about ~75 lists for the Llama tokenizer. """ class _StringLengthTokenCache: + """This is an internal data structure, that given a list of string+token pairs, + can quickly return all token ids of strings between certain lengths""" def __init__(self): self.tokens: List[int] = [] self.first_index_geq_than_length: List[int] = [0] def build(self, token_strs_to_idx: List[Tuple[str, int]]): + # TODO: If this becomes a performance bottleneck, bucket sort instead. token_strs_to_idx = sorted(token_strs_to_idx, key=lambda p:len(p[0])) self.tokens = [pair[1] for pair in token_strs_to_idx] # self.token_strs = [pair[0] for pair in token_strs_to_idx] # For debugging @@ -32,7 +35,7 @@ def build(self, token_strs_to_idx: List[Tuple[str, int]]): self.first_index_geq_than_length.append(len(token_lengths)) def get_indices_between_length(self, min_length=-1, max_length=-1) -> List[int]: - if min_length > len(self.first_index_geq_than_length): + if min_length >= len(self.first_index_geq_than_length): return [] start_index = self.first_index_geq_than_length[min_length] if min_length > 0 else 0 if max_length == 0: @@ -40,7 +43,7 @@ def get_indices_between_length(self, min_length=-1, max_length=-1) -> List[int]: elif max_length + 1 < len(self.first_index_geq_than_length): end_index = self.first_index_geq_than_length[max_length + 1] else: - end_index = -1 + end_index = len(self.tokens) return self.tokens[start_index:end_index] def __init__(self, ) -> None: diff --git a/tests/test_tokenizercaching.py b/tests/test_tokenizercaching.py index b00e475..fc837f1 100644 --- a/tests/test_tokenizercaching.py +++ b/tests/test_tokenizercaching.py @@ -5,7 +5,7 @@ def test_json_freetext_cache(): token_to_str = {} cache = JsonFreetextTokenCache() test_length = 500 - letters = "abcdefg" + letters = "abcde" num_letters = len(letters) def _register_token(token_str: str): token_idx = len(token_to_str) @@ -17,9 +17,16 @@ def _register_token(token_str: str): _register_token(letter * i) _register_token(letter * i + '"') cache.freeze() + + def _assert_allowed_tokens(_min_remaining, _max_length, _num_expected_tokens): + allowed_tokens = cache.lookup_allowed_tokens(_min_remaining, _max_length) + if len(allowed_tokens) != _num_expected_tokens: + allowed_token_strs = "|".join(token_to_str[token_idx] for token_idx in allowed_tokens) + raise Exception(f"Min={_min_remaining}, Max={_max_length}, Expected {_num_expected_tokens}, got {len(allowed_tokens)} : {allowed_token_strs}") + for min_remaining in range(0, test_length): for max_length in range(min_remaining, test_length): - allowed_tokens = cache.lookup_allowed_tokens(min_remaining, max_length) + num_expected_quote_tokens = num_letters * (max_length - min_remaining + 1) if min_remaining == 0: # at 0, there is only one quoted string (") @@ -27,6 +34,8 @@ def _register_token(token_str: str): num_expected_regular_tokens = max_length * num_letters num_expected_tokens = num_expected_quote_tokens + num_expected_regular_tokens - if len(allowed_tokens) != num_expected_tokens: - allowed_token_strs = "|".join(token_to_str[token_idx] for token_idx in allowed_tokens) - raise Exception(f"Min={min_remaining}, Max={max_length}, Expected {num_expected_tokens}, got {len(allowed_tokens)} : {allowed_token_strs}") + _assert_allowed_tokens(min_remaining, max_length, num_expected_tokens) + + _assert_allowed_tokens(0, test_length + 1, len(token_to_str)) + num_nonquote_tokens = (test_length - 1) * num_letters + _assert_allowed_tokens(test_length + 1, test_length + 1, num_nonquote_tokens)