* draft for tokenizer training script
* Update scripts/train_tokenizer.py
Co-authored-by: Jett <[email protected]>
* integrate some of the suggested changes (WIP)
* another use of tempfile (untested)
* bug fixes and test script
* split the code, reduced num. tmp files, ...
* removed tests/scripts
* updated transformers to 4.40.0
---------
Co-authored-by: Jannik Brinkmann <[email protected]>
Co-authored-by: Jannik Brinkmann <[email protected]>
1 parent cc7a6c8 · commit 51a8e57
Showing 4 changed files with 179 additions and 23 deletions.
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
import argparse
import tempfile

from delphi.train.tokenizer import (
    hf_bpe_tokenizer_to_llama_tokenizer,
    hf_dataset_to_text,
    sp_processor_to_hf_bpe_tokenizer,
    train_sentence_piece,
)


def main(
    *,
    vocab_size: int,
    dataset_name: str,
    split: str,
    column: str,
    repo_id: str,
    hf_token: str,
):
    """Trains a SentencePiece tokenizer, converts it to LlamaTokenizerFast and pushes it to the Hugging Face Hub."""
    with tempfile.TemporaryFile(mode="w+") as text_file:
        print("Loading and writing dataset to text file...")
        hf_dataset_to_text(
            dataset_name=dataset_name,
            split=split,
            column=column,
            text_file=text_file,
        )
        # Rewind so the trainer reads the temporary file from the beginning.
        text_file.seek(0)
        print("Training SentencePiece tokenizer...\n")
        sp_processor = train_sentence_piece(
            vocab_size=vocab_size,
            sentence_iterator=text_file,
        )
        print("\nConverting SentencePiece tokenizer to Llama tokenizer...")
        hf_bpe_tokenizer = sp_processor_to_hf_bpe_tokenizer(sp_processor)
        llama_tokenizer = hf_bpe_tokenizer_to_llama_tokenizer(hf_bpe_tokenizer)
        print("Pushing Llama tokenizer to Hugging Face Hub...")
        llama_tokenizer.push_to_hub(
            repo_id=repo_id,
            token=hf_token,
        )
        print("Done.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a SentencePiece tokenizer and convert to HF"
    )
    parser.add_argument(
        "--vocab-size",
        "-v",
        type=int,
        help="Vocabulary size of the tokenizer",
    )
    parser.add_argument(
        "--dataset-name",
        "-d",
        type=str,
        help="Dataset name with or without delphi-suite/ prefix",
    )
    parser.add_argument(
        "--split",
        "-s",
        type=str,
        default="train",
        help="Split of the dataset to be used for training, supports slicing like 'train[:10%%]'",
    )
    parser.add_argument(
        "--column",
        "-c",
        type=str,
        help="Column of the dataset to be used for training",
    )
    parser.add_argument(
        "--repo-id",
        "-r",
        type=str,
        help="Hugging Face repository ID",
    )
    parser.add_argument(
        "--hf-token",
        "-t",
        type=str,
        help="Hugging Face API token",
    )
    args = parser.parse_args()
    main(
        vocab_size=args.vocab_size,
        dataset_name=args.dataset_name,
        split=args.split,
        column=args.column,
        repo_id=args.repo_id,
        hf_token=args.hf_token,
    )
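Once the script has pushed the tokenizer, it can be loaded back like any other tokenizer on the Hub. A minimal sketch, assuming a placeholder repository ID rather than one actually created by this commit:

from transformers import AutoTokenizer

# "delphi-suite/example-tokenizer" is a placeholder repo ID, used only for illustration.
tokenizer = AutoTokenizer.from_pretrained("delphi-suite/example-tokenizer")
ids = tokenizer("Attention is all you need")["input_ids"]
print(ids)
print(tokenizer.decode(ids))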
@@ -0,0 +1,81 @@
import io
import os
import tempfile
from typing import cast

from datasets import Dataset, load_dataset
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
from tokenizers import SentencePieceBPETokenizer  # type: ignore
from transformers import LlamaTokenizerFast


def hf_dataset_to_text(
    dataset_name: str, split: str, column: str, text_file: io.TextIOBase
):
    """Writes the given column of a Hugging Face dataset to text_file, one stripped example per line."""
    dataset = cast(Dataset, load_dataset(dataset_name, split=split))
    for text in dataset[column]:
        text = text.strip()
        text_file.write(text + "\n")


def train_sentence_piece(
    vocab_size: int,
    sentence_iterator: io.TextIOBase,
) -> SentencePieceProcessor:
    """Trains a custom SentencePiece tokenizer."""
    model = io.BytesIO()
    SentencePieceTrainer.train(  # type: ignore
        sentence_iterator=sentence_iterator,
        model_writer=model,
        model_type="bpe",
        vocab_size=vocab_size,
        self_test_sample_size=0,
        character_coverage=1.0,
        num_threads=os.cpu_count(),
        split_digits=True,
        allow_whitespace_only_pieces=True,
        byte_fallback=True,
        unk_surface=r" \342\201\207 ",
        normalization_rule_name="identity",
    )
    return SentencePieceProcessor(model_proto=model.getvalue())  # type: ignore


def sp_processor_to_hf_bpe_tokenizer(
    sp_processor: SentencePieceProcessor,
) -> SentencePieceBPETokenizer:
    """Converts a SentencePieceProcessor to a SentencePieceBPETokenizer."""
    vocab = {
        sp_processor.id_to_piece(index): index  # type: ignore
        for index in range(sp_processor.GetPieceSize())
    }
    # Reconstruct the BPE merge list: any pair of pieces whose concatenation is
    # itself a piece must have been produced by a merge. Order merges by the id
    # of the merged piece, since lower ids were learned earlier.
    merges = []
    for piece_l in vocab.keys():
        for piece_r in vocab.keys():
            merge = f"{piece_l}{piece_r}"
            piece_id = vocab.get(merge, None)
            # Compare against None so a valid piece id of 0 is not dropped.
            if piece_id is not None:
                merges += [(piece_l, piece_r, piece_id)]
    merges = sorted(merges, key=lambda val: val[2])
    merges = [(val[0], val[1]) for val in merges]

    return SentencePieceBPETokenizer(vocab, merges)


def hf_bpe_tokenizer_to_llama_tokenizer(
    hf_bpe_tokenizer: SentencePieceBPETokenizer,
) -> LlamaTokenizerFast:
    """Serializes the BPE tokenizer to a temporary JSON file and reloads it as a LlamaTokenizerFast."""
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".json") as tmp_json_file:
        hf_bpe_tokenizer.save(tmp_json_file.name)
        return LlamaTokenizerFast(
            tokenizer_file=tmp_json_file.name,
            unk_token="<unk>",
            unk_token_id=0,
            bos_token="<s>",
            bos_token_id=1,
            eos_token="</s>",
            eos_token_id=2,
            pad_token="<pad>",
            pad_token_id=3,
            padding_side="right",
        )
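These helpers can also be exercised locally, without pushing to the Hub. A minimal sketch, assuming a placeholder corpus file ("corpus.txt", one sentence per line) that is large enough for SentencePiece to reach the requested vocabulary size:

from delphi.train.tokenizer import (
    hf_bpe_tokenizer_to_llama_tokenizer,
    sp_processor_to_hf_bpe_tokenizer,
    train_sentence_piece,
)

# "corpus.txt" is a placeholder path: plain text, one training sentence per line.
with open("corpus.txt", "r") as corpus:
    sp_processor = train_sentence_piece(vocab_size=1024, sentence_iterator=corpus)

llama_tokenizer = hf_bpe_tokenizer_to_llama_tokenizer(
    sp_processor_to_hf_bpe_tokenizer(sp_processor)
)
print(llama_tokenizer.tokenize("SentencePiece BPE with byte fallback"))
# Save locally for inspection instead of calling push_to_hub.
llama_tokenizer.save_pretrained("my_tokenizer")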
This file was deleted.