From d34296ab509d405b173763325395db9c1b4d5919 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Sun, 18 Feb 2024 19:08:42 +0100
Subject: [PATCH 1/3] Extract vocab from ExLlamaV2Tokenizer id_to_piece list

---
 lmformatenforcer/integrations/exllamav2.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/lmformatenforcer/integrations/exllamav2.py b/lmformatenforcer/integrations/exllamav2.py
index 7fa9ee4..31d822f 100644
--- a/lmformatenforcer/integrations/exllamav2.py
+++ b/lmformatenforcer/integrations/exllamav2.py
@@ -10,22 +10,21 @@
 
 
 def _build_regular_tokens_list(tokenizer: ExLlamaV2Tokenizer) -> List[Tuple[int, str, bool]]:
-    token_0 = tokenizer.encode("0")[0]
-    regular_tokens = []
     vocab_size = tokenizer.tokenizer.vocab_size()
-    all_special_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id]
+    all_special_ids = set(tokenizer.extended_id_to_piece.keys())
+    all_special_ids.update({ tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id })
+    id_to_piece = tokenizer.get_id_to_piece_list()
+    regular_tokens = []
     for token_idx in range(vocab_size):
         if token_idx in all_special_ids:
            continue
-        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
-        tensor_after_0 = torch.tensor(token_0.tolist() + [token_idx], dtype=torch.long)
-        decoded_after_0 = tokenizer.decode(tensor_after_0)[1:]
-        decoded_regular = tokenizer.decode(token_0)
-        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
-        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+        decoded = id_to_piece[token_idx]
+        is_word_start_token = len(decoded) > 0 and decoded[0] == " "
+        regular_tokens.append((token_idx, decoded, is_word_start_token))
     return regular_tokens
 
+
 def build_token_enforcer_tokenizer_data(tokenizer: ExLlamaV2Tokenizer) -> TokenEnforcerTokenizerData:
     regular_tokens = _build_regular_tokens_list(tokenizer)
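
With this patch, _build_regular_tokens_list() reads each token's text straight from the
tokenizer's id-to-piece list and flags word-start tokens by their leading space, instead of
round-tripping every id through decode() behind a sentinel token. A minimal, self-contained
sketch of that loop, using a made-up piece list rather than the real ExLlamaV2 vocabulary:

    # Toy stand-in for tokenizer.get_id_to_piece_list(); word-start pieces are
    # assumed to carry a literal leading space, as the patched code expects.
    id_to_piece = ["<s>", " Hello", "Hello", " world", "wor", "ld"]
    special_ids = {0}  # pretend id 0 is the only special token

    regular_tokens = []
    for token_id, piece in enumerate(id_to_piece):
        if token_id in special_ids:
            continue
        is_word_start = len(piece) > 0 and piece[0] == " "
        regular_tokens.append((token_id, piece, is_word_start))

    print(regular_tokens)
    # [(1, ' Hello', True), (2, 'Hello', False), (3, ' world', True), (4, 'wor', False), (5, 'ld', False)]
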
From 980b632e10146ff7ae75681fab28f2e514a366c6 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Sun, 18 Feb 2024 20:52:51 +0100
Subject: [PATCH 2/3] Optimize initialization of JsonFreetextTokenCache

---
 lmformatenforcer/tokenizerprefixtree.py | 44 +++++++++++++++++++++----------
 1 file changed, 27 insertions(+), 17 deletions(-)

diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py
index 3bfd82c..a3d7c08 100644
--- a/lmformatenforcer/tokenizerprefixtree.py
+++ b/lmformatenforcer/tokenizerprefixtree.py
@@ -17,9 +17,10 @@ class JsonFreetextTokenCache:
     After deduplication, this results in about ~75 lists for the Llama tokenizer.
     """
     def __init__(self, ) -> None:
-        self.token_str_to_num: Dict[str, int] = {}
+        self.token_num_to_str: Dict[int, str] = {}
         self.allowlist_cache: Dict[Tuple[int, int], Tuple[int, ...]] = {}
         self.max_token_len = 0
+        self.max_token_int = -1
         self.max_allowed_token_len = 32
 
     def add_token(self, token_str: str, token_int: int):
@@ -39,7 +40,8 @@ def add_token(self, token_str: str, token_int: int):
             # TODO: Should we instead ALWAYS allow them?
             return
 
-        self.token_str_to_num[token_str] = token_int
+        self.token_num_to_str[token_int] = token_str
+        self.max_token_int = max(self.max_token_int, token_int)
         self.max_token_len = min(max(self.max_token_len, len(token_str)), self.max_allowed_token_len)
 
     def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, ...]:
@@ -55,25 +57,37 @@ def freeze(self) -> None:
         Precalculate token allowlists for all valid combinations of `min_remaining` and `max_len`
         based on the tokens that were added with `add_token()`.
         """
-        all_tokens: List[str] = sorted(self.token_str_to_num.keys())
+        all_tokens: List[str] = [self.token_num_to_str.get(i, None) for i in range(self.max_token_int + 1)]
         assert all_tokens, "Cannot precalculate allowlists for an empty token list"
         assert not any(t == '' for t in all_tokens), "Tokenizer must not contain empty tokens"
 
         def _valid_for_min_remaining(token, min_remaining):
-            return not token.endswith('"') or len(token.rstrip('"')) >= min_remaining
+            return token is not None and (not token.endswith('"') or len(token.rstrip('"')) >= min_remaining)
 
         def _valid_for_max_len(token, max_len):
-            return len(token.rstrip('"')) <= max_len
+            return token is not None and len(token.rstrip('"')) <= max_len
+
+        # Precalculate valid token sets
+        valid_for_min_remaining_sets = []
+        for min_remaining in range(self.max_token_len + 1):
+            valid_for_min_remaining_sets.append(set([
+                i for i in range(len(all_tokens))
+                if _valid_for_min_remaining(all_tokens[i], min_remaining)
+            ]))
+
+        valid_for_max_len_sets = []
+        for max_len in range(self.max_token_len + 1):
+            valid_for_max_len_sets.append(set([
+                i for i in range(len(all_tokens))
+                if _valid_for_max_len(all_tokens[i], max_len)
+            ]))
 
         # Make a 2D array of constrained allowlists, indexed by tuple `(min_remaining, max_len)`
         token_lists = {}
         for min_remaining in range(self.max_token_len + 1):
-            for max_len in range(self.max_token_len + 1):
-                if max_len >= min_remaining:  # Skip combinations that are never used
-                    token_lists[(min_remaining, max_len)] = tuple(sorted([
-                        token for token in all_tokens
-                        if _valid_for_min_remaining(token, min_remaining) and _valid_for_max_len(token, max_len)
-                    ]))
+            for max_len in range(min_remaining, self.max_token_len + 1):
+                ids = tuple(valid_for_min_remaining_sets[min_remaining] & valid_for_max_len_sets[max_len])
+                token_lists[(min_remaining, max_len)] = ids
 
         # Deduplicate the lists to save RAM as many of them will be identical
         unique_lists = set(token_lists.values())
@@ -83,12 +97,8 @@ def _valid_for_max_len(token, max_len):
                     token_lists[key] = uniq
                     break
 
-        # Turn token strings into token numbers
-        self.allowlist_cache = {
-            key: tuple(self.token_str_to_num[t] for t in lst)
-            for key, lst in token_lists.items()
-        }
-        del self.token_str_to_num
+        self.allowlist_cache = token_lists
+        del self.token_num_to_str
 
 
 class TokenizerPrefixTree:
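
After this patch, freeze() builds one set of token ids per `min_remaining` bound and one per
`max_len` bound up front, so each (min_remaining, max_len) cell of the cache is a single set
intersection instead of a fresh pass over the whole vocabulary. A rough standalone sketch of
the idea with a handful of made-up token strings (not the library code itself):

    tokens = ['ab', 'a"', 'abcd"', 'x', 'hello']  # hypothetical JSON freetext tokens
    longest = max(len(t.rstrip('"')) for t in tokens)

    # A token that closes the string with '"' only qualifies if the text before
    # the quote is at least min_remaining long; every token must fit in max_len.
    min_sets = [{i for i, t in enumerate(tokens)
                 if not t.endswith('"') or len(t.rstrip('"')) >= m}
                for m in range(longest + 1)]
    max_sets = [{i for i, t in enumerate(tokens) if len(t.rstrip('"')) <= m}
                for m in range(longest + 1)]

    allowlists = {}
    for min_remaining in range(longest + 1):
        for max_len in range(min_remaining, longest + 1):  # max_len < min_remaining is never queried
            allowlists[(min_remaining, max_len)] = tuple(min_sets[min_remaining] & max_sets[max_len])

    print(allowlists[(1, 2)])  # ids usable with >= 1 char still required and <= 2 chars of room

The per-cell work drops from evaluating two Python predicates per token (plus a sort) to one
C-level set intersection, which is why the cache now stores token ids rather than sorted token
strings.
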
From 4695657fd5c91ca5cc8e2bd5144236b4b35b49ee Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Mon, 19 Feb 2024 03:42:58 +0100
Subject: [PATCH 3/3] Precompute token lengths and simplify set construction
 (current bottleneck), further 2x speedup

---
 lmformatenforcer/tokenizerprefixtree.py | 36 ++++++++++++++-----------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py
index a3d7c08..a610c51 100644
--- a/lmformatenforcer/tokenizerprefixtree.py
+++ b/lmformatenforcer/tokenizerprefixtree.py
@@ -61,41 +61,47 @@ def freeze(self) -> None:
         assert all_tokens, "Cannot precalculate allowlists for an empty token list"
         assert not any(t == '' for t in all_tokens), "Tokenizer must not contain empty tokens"
 
-        def _valid_for_min_remaining(token, min_remaining):
-            return token is not None and (not token.endswith('"') or len(token.rstrip('"')) >= min_remaining)
+        # Precalculate lengths
+        minvalue = -10
+        maxvalue = self.max_token_len + 10
 
-        def _valid_for_max_len(token, max_len):
-            return token is not None and len(token.rstrip('"')) <= max_len
+        all_tokens_min = [
+            minvalue if (t is None) else (maxvalue if not t.endswith('"') else len(t.rstrip('"')))
+            for t in all_tokens
+        ]
+
+        all_tokens_max = [
+            maxvalue if (t is None) else len(t.rstrip('"'))
+            for t in all_tokens
+        ]
 
         # Precalculate valid token sets
         valid_for_min_remaining_sets = []
         for min_remaining in range(self.max_token_len + 1):
             valid_for_min_remaining_sets.append(set([
                 i for i in range(len(all_tokens))
-                if _valid_for_min_remaining(all_tokens[i], min_remaining)
+                if all_tokens_min[i] >= min_remaining
             ]))
 
         valid_for_max_len_sets = []
         for max_len in range(self.max_token_len + 1):
             valid_for_max_len_sets.append(set([
                 i for i in range(len(all_tokens))
-                if _valid_for_max_len(all_tokens[i], max_len)
+                if all_tokens_max[i] <= max_len
             ]))
 
         # Make a 2D array of constrained allowlists, indexed by tuple `(min_remaining, max_len)`
+        # As many of them will be identical, avoid storing duplicate lists to save RAM
         token_lists = {}
+        unique_lists = {}
         for min_remaining in range(self.max_token_len + 1):
             for max_len in range(min_remaining, self.max_token_len + 1):
                 ids = tuple(valid_for_min_remaining_sets[min_remaining] & valid_for_max_len_sets[max_len])
-                token_lists[(min_remaining, max_len)] = ids
-
-        # Deduplicate the lists to save RAM as many of them will be identical
-        unique_lists = set(token_lists.values())
-        for key, lst in token_lists.items():
-            for uniq in unique_lists:
-                if len(uniq) == len(lst) and uniq == lst:
-                    token_lists[key] = uniq
-                    break
+                if ids in unique_lists:
+                    token_lists[(min_remaining, max_len)] = unique_lists[ids]  # Save reference
+                else:
+                    token_lists[(min_remaining, max_len)] = ids  # Save new unique set
+                    unique_lists[ids] = ids
 
         self.allowlist_cache = token_lists
         del self.token_num_to_str
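
This last patch folds deduplication into the main loop: each freshly built allowlist tuple is
looked up in a dict keyed by the tuple itself, so cells that produce identical allowlists share
a single object instead of being compared pairwise afterwards. A small sketch of that interning
pattern with arbitrary example tuples:

    # Tuples standing in for per-(min_remaining, max_len) allowlists; many cells
    # yield identical results, so equal tuples should share one object in memory.
    built = [(1, 2, 3), (4,), (1, 2, 3), (4,), (5, 6)]

    unique = {}        # maps each tuple value to the first object that had it
    interned = []
    for ids in built:
        if ids in unique:
            interned.append(unique[ids])   # reuse the stored tuple
        else:
            unique[ids] = ids
            interned.append(ids)

    assert interned[0] is interned[2]      # same object, not merely equal
    print(f"{len(unique)} unique allowlists for {len(built)} cells")

The length precompute in the same patch serves a similar purpose: None placeholders and tokens
that do not end in '"' are mapped to sentinel lengths (minvalue / maxvalue) once, so the
per-token filter becomes a plain integer comparison instead of a Python function call.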