-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Tiktoken support for TRTLLM #10306
Open
meatybobby
wants to merge
13
commits into
main
Choose a base branch
from
bobchen/add_tiktoken
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+145
−4
Open
Changes from 11 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
16e8eb7
Add tiktoken tokenizer
meatybobby 4cd21a3
Merge branch 'main' into bobchen/add_tiktoken
meatybobby f24af5f
Add special token
meatybobby a61a02f
Remove unused import
meatybobby b857102
Apply isort and black reformatting
meatybobby 672f4ef
Remove unused import
meatybobby b1b7a50
Merge branch 'main' into bobchen/add_tiktoken
meatybobby 9df9f91
Merge branch 'main' into bobchen/add_tiktoken
jubick1337 38b8042
Merge branch 'main' into bobchen/add_tiktoken
meatybobby 986966d
Fix after merge
meatybobby b367e08
Merge branch 'main' into bobchen/add_tiktoken
meatybobby 57762b9
Change qnemo loading
meatybobby 5806398
Apply isort and black reformatting
meatybobby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import base64 | ||
import json | ||
from pathlib import Path | ||
from typing import Dict, Optional | ||
|
||
import numpy as np | ||
import tiktoken | ||
import torch | ||
|
||
PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" | ||
DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17 # 131072 | ||
SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"] | ||
SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>" | ||
|
||
|
||
def reload_mergeable_ranks( | ||
path: str, | ||
max_vocab: Optional[int] = None, | ||
) -> Dict[bytes, int]: | ||
""" | ||
Reload the tokenizer JSON file and convert it to Tiktoken format. | ||
""" | ||
assert path.endswith(".json") | ||
|
||
# reload vocab | ||
with open(path, "r", encoding='utf-8') as f: | ||
vocab = json.load(f) | ||
assert isinstance(vocab, list) | ||
print(f"Vocab size: {len(vocab)}") | ||
if max_vocab is not None: | ||
vocab = vocab[:max_vocab] | ||
print(f"Cutting vocab to first {len(vocab)} tokens.") | ||
|
||
# build ranks | ||
ranks: Dict[bytes, int] = {} | ||
for i, x in enumerate(vocab): | ||
assert x.keys() == {"rank", "token_bytes", "token_str"} | ||
assert x["rank"] == i | ||
merge = base64.b64decode(x["token_bytes"]) | ||
assert i >= 256 or merge == bytes([i]) | ||
ranks[merge] = x["rank"] | ||
|
||
# sanity check | ||
assert len(ranks) == len(vocab) | ||
assert set(ranks.values()) == set(range(len(ranks))) | ||
|
||
return ranks | ||
|
||
|
||
class TiktokenTokenizer: | ||
def __init__(self, vocab_file: str): | ||
|
||
self.num_special_tokens = 1000 | ||
vocab_size = DEFAULT_TIKTOKEN_MAX_VOCAB | ||
pattern = PATTERN_TIKTOKEN | ||
special_tokens = SPECIAL_TOKENS.copy() | ||
inner_vocab_size = vocab_size - self.num_special_tokens | ||
|
||
token2id = reload_mergeable_ranks(vocab_file, max_vocab=inner_vocab_size) | ||
self.tokenizer = tiktoken.Encoding( | ||
name=Path(vocab_file).parent.name, | ||
pat_str=pattern, | ||
mergeable_ranks=token2id, | ||
special_tokens={}, # special tokens are handled manually | ||
) | ||
|
||
# BOS / EOS / Pad token IDs | ||
self._bos_id = special_tokens.index("<s>") | ||
self._eos_id = special_tokens.index("</s>") | ||
|
||
def encode(self, text): | ||
tokens = self.tokenizer.encode(text) | ||
tokens = [t + self.num_special_tokens for t in tokens] | ||
return tokens | ||
|
||
def decode(self, tokens): | ||
# Filter out special tokens and adjust the remaining tokens | ||
adjusted_tokens = [ | ||
t - self.num_special_tokens | ||
for t in tokens | ||
if t not in {self._bos_id, self._eos_id} and t >= self.num_special_tokens | ||
] | ||
|
||
# Decode only if there are tokens left after filtering | ||
if adjusted_tokens: | ||
return self.tokenizer.decode(adjusted_tokens) | ||
else: | ||
return "" # Return an empty string if all tokens were filtered out | ||
|
||
def batch_decode(self, ids): | ||
if isinstance(ids, np.ndarray) or torch.is_tensor(ids): | ||
ids = ids.tolist() | ||
|
||
if isinstance(ids[0], list): | ||
ids = ids[0] | ||
|
||
return self.decode(ids) | ||
|
||
@property | ||
def pad_id(self): | ||
return self._eos_id | ||
|
||
@property | ||
def bos_token_id(self): | ||
return self._bos_id | ||
|
||
@property | ||
def eos_token_id(self): | ||
return self._eos_id |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,11 +13,14 @@ | |
# limitations under the License. | ||
|
||
import os | ||
import shutil | ||
import tempfile | ||
|
||
from omegaconf import OmegaConf | ||
from transformers import AutoTokenizer | ||
|
||
from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer | ||
from nemo.export.tiktoken_tokenizer import TiktokenTokenizer | ||
|
||
# TODO: use get_nmt_tokenizer helper below to instantiate tokenizer once environment / dependencies get stable | ||
# from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer | ||
|
@@ -43,6 +46,12 @@ def get_nmt_tokenizer(nemo_checkpoint_path: str): | |
tokenizer = SentencePieceTokenizer( | ||
model_path=os.path.join(nemo_checkpoint_path, tokenizer_cfg.model), legacy=legacy | ||
) | ||
elif library == "tiktoken": | ||
tmp_dir = tempfile.TemporaryDirectory() | ||
tmp_path = os.path.join(tmp_dir.name, "vocab.json") | ||
vocab_file = os.path.join(nemo_checkpoint_path, tokenizer_cfg.vocab_file) | ||
shutil.copy(vocab_file, tmp_path) | ||
tokenizer = TiktokenTokenizer(vocab_file=tmp_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this be just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @janekl I've changed it |
||
else: | ||
raise NotImplementedError("Currently we only support 'huggingface' and 'sentencepiece' tokenizer libraries.") | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Check notice
Code scanning / CodeQL
Unused import Note