tokenizer training script (#103)
* draft for tokenizer training script

* Update scripts/train_tokenizer.py

Co-authored-by: Jett <[email protected]>

* integrate some of the suggested changes (WIP)

* another use of tempfile (untested)

* bug fixes and test script

* split the code, reduced num. tmp files, ...

* removed tests/scripts

* updated transformers to 4.40.0

---------

Co-authored-by: Jannik Brinkmann <[email protected]>
Co-authored-by: Jannik Brinkmann <[email protected]>
3 people authored Apr 19, 2024
1 parent cc7a6c8 commit 51a8e57
Showing 4 changed files with 179 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ dependencies = [
"dacite==1.8.1",
"panel==1.4.0",
"jupyter_bokeh==4.0.1",
"transformers==4.39.2",
"transformers==4.40.0",
]

[project.optional-dependencies]
97 changes: 97 additions & 0 deletions scripts/train_tokenizer.py
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
import argparse
import tempfile

from delphi.train.tokenizer import (
hf_bpe_tokenizer_to_llama_tokenizer,
hf_dataset_to_text,
sp_processor_to_hf_bpe_tokenizer,
train_sentence_piece,
)


def main(
*,
vocab_size: int,
dataset_name: str,
split: str,
column: str,
repo_id: str,
hf_token: str,
):
"""Trains a SentencePiece tokenizer, converts it to LlamaTokenizerFast and pushes it to the Hugging Face Hub."""
with tempfile.TemporaryFile(mode="w+") as text_file:
print("Loading and writing dataset to text file...")
hf_dataset_to_text(
dataset_name=dataset_name,
split=split,
column=column,
text_file=text_file,
)
text_file.seek(0)
print("Training SentencePiece tokenizer...\n")
sp_processor = train_sentence_piece(
vocab_size=vocab_size,
sentence_iterator=text_file,
)
print("\nConverting SentencePiece tokenizer Llama tokenizer...")
hf_bpe_tokenizer = sp_processor_to_hf_bpe_tokenizer(sp_processor)
llama_tokenizer = hf_bpe_tokenizer_to_llama_tokenizer(hf_bpe_tokenizer)
print("Pushing Llama tokenizer to Hugging Face Hub...")
llama_tokenizer.push_to_hub(
repo_id=repo_id,
token=hf_token,
)
print("Done.")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a SentencePiece tokenizer and convert to HF"
)
parser.add_argument(
"--vocab-size",
"-v",
type=int,
help="Vocabulary size of the tokenizer",
)
parser.add_argument(
"--dataset-name",
"-d",
type=str,
help="Dataset name with or without delphi-suite/ prefix",
)
parser.add_argument(
"--split",
"-s",
type=str,
default="train",
help="Split of the dataset to be used for training, supports slicing like 'train[:10%%]'",
)
parser.add_argument(
"--column",
"-c",
type=str,
help="Column of the dataset to be used for training",
)
parser.add_argument(
"--repo-id",
"-r",
type=str,
help="Hugging Face repository ID",
)
parser.add_argument(
"--hf-token",
"-t",
type=str,
help="Hugging Face API token",
)
args = parser.parse_args()
main(
vocab_size=args.vocab_size,
dataset_name=args.dataset_name,
split=args.split,
column=args.column,
repo_id=args.repo_id,
hf_token=args.hf_token,
)
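
As a usage sketch: once this script has pushed a tokenizer, it could be loaded back from the Hub and spot-checked roughly as below. The repository id is a hypothetical placeholder for whatever --repo-id was passed.

from transformers import AutoTokenizer

# Hypothetical repo id; substitute the value passed via --repo-id.
tokenizer = AutoTokenizer.from_pretrained("delphi-suite/example-tokenizer")
text = "Training a small BPE tokenizer."
ids = tokenizer(text)["input_ids"]
print(ids)
# Decoding should reproduce the input text (special tokens stripped).
print(tokenizer.decode(ids, skip_special_tokens=True))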
81 changes: 81 additions & 0 deletions src/delphi/train/tokenizer.py
@@ -0,0 +1,81 @@
import io
import os
import tempfile
from typing import cast

from datasets import Dataset, load_dataset
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
from tokenizers import SentencePieceBPETokenizer # type: ignore
from transformers import LlamaTokenizerFast


def hf_dataset_to_text(
dataset_name: str, split: str, column: str, text_file: io.TextIOBase
):
dataset = cast(Dataset, load_dataset(dataset_name, split=split))
for text in dataset[column]:
text = text.strip()
text_file.write(text + "\n")


def train_sentence_piece(
vocab_size: int,
sentence_iterator: io.TextIOBase,
) -> SentencePieceProcessor:
"""Trains a custom SentencePiece tokenizer."""
model = io.BytesIO()
SentencePieceTrainer.train( # type: ignore
sentence_iterator=sentence_iterator,
model_writer=model,
model_type="bpe",
vocab_size=vocab_size,
self_test_sample_size=0,
character_coverage=1.0,
num_threads=os.cpu_count(),
split_digits=True,
allow_whitespace_only_pieces=True,
byte_fallback=True,
unk_surface=r" \342\201\207 ",
normalization_rule_name="identity",
)
return SentencePieceProcessor(model_proto=model.getvalue()) # type: ignore


def sp_processor_to_hf_bpe_tokenizer(
sp_processor: SentencePieceProcessor,
) -> SentencePieceBPETokenizer:
"""Converts a SentencePieceProcessor to a SentencePieceBPETokenizer."""
vocab = {
sp_processor.id_to_piece(index): index # type: ignore
for index in range(sp_processor.GetPieceSize())
}
merges = []
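    # Reconstruct the BPE merge list: any concatenation of two pieces that is
    # itself a piece must have come from a merge, and lower piece ids correspond
    # to merges learned earlier, so sorting by id approximates the original
    # merge order. This scan is quadratic in vocabulary size but only runs once.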
for piece_l in vocab.keys():
for piece_r in vocab.keys():
merge = f"{piece_l}{piece_r}"
piece_id = vocab.get(merge, None)
            if piece_id is not None:
merges += [(piece_l, piece_r, piece_id)]
merges = sorted(merges, key=lambda val: val[2])
merges = [(val[0], val[1]) for val in merges]

return SentencePieceBPETokenizer(vocab, merges)


def hf_bpe_tokenizer_to_llama_tokenizer(
hf_bpe_tokenizer: SentencePieceBPETokenizer,
) -> LlamaTokenizerFast:
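    # The BPE tokenizer is saved to a temporary tokenizers JSON file, which
    # LlamaTokenizerFast then loads back via tokenizer_file.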
with tempfile.NamedTemporaryFile(mode="w+", suffix=".json") as tmp_json_file:
hf_bpe_tokenizer.save(tmp_json_file.name)
return LlamaTokenizerFast(
tokenizer_file=tmp_json_file.name,
unk_token="<unk>",
unk_token_id=0,
bos_token="<s>",
bos_token_id=1,
eos_token="</s>",
eos_token_id=2,
pad_token="<pad>",
pad_token_id=3,
padding_side="right",
)
22 changes: 0 additions & 22 deletions tests/scripts/functional_test_generate_logprobs.sh

This file was deleted.
