diff --git a/lmformatenforcer/integrations/exllamav2.py b/lmformatenforcer/integrations/exllamav2.py
index 7fa9ee4..31d822f 100644
--- a/lmformatenforcer/integrations/exllamav2.py
+++ b/lmformatenforcer/integrations/exllamav2.py
@@ -10,22 +10,21 @@
 def _build_regular_tokens_list(tokenizer: ExLlamaV2Tokenizer) -> List[Tuple[int, str, bool]]:
-    token_0 = tokenizer.encode("0")[0]
-    regular_tokens = []
     vocab_size = tokenizer.tokenizer.vocab_size()
-    all_special_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id]
+    all_special_ids = set(tokenizer.extended_id_to_piece.keys())
+    all_special_ids.update({ tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id })
+    id_to_piece = tokenizer.get_id_to_piece_list()
+    regular_tokens = []
     for token_idx in range(vocab_size):
         if token_idx in all_special_ids:
             continue
-        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
-        tensor_after_0 = torch.tensor(token_0.tolist() + [token_idx], dtype=torch.long)
-        decoded_after_0 = tokenizer.decode(tensor_after_0)[1:]
-        decoded_regular = tokenizer.decode(token_0)
-        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
-        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+        decoded = id_to_piece[token_idx]
+        is_word_start_token = len(decoded) > 0 and decoded[0] == " "
+        regular_tokens.append((token_idx, decoded, is_word_start_token))
     return regular_tokens
 
+
 def build_token_enforcer_tokenizer_data(tokenizer: ExLlamaV2Tokenizer) -> TokenEnforcerTokenizerData:
     regular_tokens = _build_regular_tokens_list(tokenizer)
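
For reference, a minimal, self-contained sketch of the word-start heuristic the new code relies on: given a plain id-to-piece mapping, a token is treated as a word-start token exactly when its decoded piece begins with a space. The `classify_tokens` helper and the toy vocabulary below are hypothetical illustrations, not part of the library, and stand in for `tokenizer.get_id_to_piece_list()` so no real ExLlamaV2Tokenizer is needed.

```python
from typing import List, Set, Tuple

def classify_tokens(id_to_piece: List[str], special_ids: Set[int]) -> List[Tuple[int, str, bool]]:
    """Hypothetical stand-in for _build_regular_tokens_list over a plain piece list."""
    regular_tokens = []
    for token_idx, decoded in enumerate(id_to_piece):
        if token_idx in special_ids:
            continue
        # A piece that begins with a space starts a new word when decoded.
        is_word_start_token = len(decoded) > 0 and decoded[0] == " "
        regular_tokens.append((token_idx, decoded, is_word_start_token))
    return regular_tokens

# Toy vocabulary: id 0 is a special token, " Hello" starts a word, "llo" does not.
print(classify_tokens(["<s>", " Hello", "llo"], special_ids={0}))
# [(1, ' Hello', True), (2, 'llo', False)]
```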