tokenizer & tokenization improvements (#136)
* add <pad> in tokenizer training

* allow saving tokenizer locally

* allow saving tokenized dataset locally

* tokenize into seq_len instead of seq_len+1

* logging

* update tokenization test (new tokenizer with pad token)
jettjaniak authored May 15, 2024
1 parent ab2fde2 commit 0a58825
Showing 6 changed files with 90 additions and 69 deletions.
1 change: 1 addition & 0 deletions configs/stories/llama2/base.json
@@ -8,6 +8,7 @@
"rms_norm_eps": 1e-06,
"bos_token_id": 0,
"eos_token_id": 1,
"pad_token_id": 2,
"tie_word_embeddings": false,
"rope_theta": 10000.0,
"rope_scaling": null,
1 change: 1 addition & 0 deletions configs/stories/mamba/base.json
@@ -6,6 +6,7 @@
"layer_norm_epsilon": 1e-5,
"bos_token_id": 0,
"eos_token_id": 1,
"pad_token_id": 2,
"expand": 2,
"conv_kernel": 4,
"use_bias": false,
59 changes: 46 additions & 13 deletions scripts/tokenize_dataset.py
@@ -1,11 +1,14 @@
#!/usr/bin/env python3
import argparse
import io
import os
from pathlib import Path

from datasets import Dataset, Features, Value, load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_and_upload_split
from delphi.dataset.tokenization import get_tokenized_chunks

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="", allow_abbrev=False)
@@ -31,19 +34,24 @@
required=True,
help="Split of the dataset to be tokenized, supports slicing like 'train[:10%%]'",
)
parser.add_argument(
"--out-dir",
type=str,
required=False,
help="Local directory to save the resulting dataset",
)
parser.add_argument(
"--out-repo-id",
"-o",
type=str,
required=True,
help="Name of the tokenized dataset to upload to huggingface",
required=False,
help="HF repo id to upload the resulting dataset",
)
parser.add_argument(
"--tokenizer",
"-r",
"-t",
type=str,
required=True,
help="Name of the tokenizer from huggingface",
help="HF repo id or local directory containing the tokenizer",
)
parser.add_argument(
"--seq-len",
@@ -67,6 +75,9 @@
help="Size of the parquet chunks uploaded to HuggingFace",
)
args = parser.parse_args()
assert (
args.out_repo_id or args.out_dir
), "You need to provide --out-repo-id or --out-dir"

print(f"Loading dataset '{args.in_repo_id}'...")
in_dataset_split = load_dataset(
@@ -75,20 +86,42 @@
features=Features({args.feature: Value("string")}),
)
assert isinstance(in_dataset_split, Dataset)
print(f"Loading tokenizer '{args.tokenizer}'...")
print(f"Loading tokenizer from '{args.tokenizer}'...")
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
assert tokenizer.bos_token_id is not None, "Tokenizer must have a bos_token_id"
assert tokenizer.eos_token_id is not None, "Tokenizer must have an eos_token_id"

api = HfApi()
api.create_repo(repo_id=args.out_repo_id, repo_type="dataset", exist_ok=True)
tokenize_and_upload_split(
api = None
if args.out_repo_id:
api = HfApi()
api.create_repo(repo_id=args.out_repo_id, repo_type="dataset", exist_ok=True)
if args.out_dir:
os.makedirs(args.out_dir, exist_ok=True)

ds_chunks_it = get_tokenized_chunks(
dataset_split=in_dataset_split,
split_name=args.split.split("[")[0],
tokenizer=tokenizer,
seq_len=args.seq_len,
batch_size=args.batch_size,
chunk_size=args.chunk_size,
out_repo_id=args.out_repo_id,
api=api,
)

print(f"Tokenizing split='{args.split}'...")
split_name = args.split.split("[")[0]
for chunk_idx, ds_chunk in enumerate(ds_chunks_it):
chunk_name = f"{split_name}-{chunk_idx:05}.parquet"
if args.out_dir:
ds_parquet_chunk = Path(args.out_dir) / chunk_name
print(f"Saving '{ds_parquet_chunk}'...")
else:
ds_parquet_chunk = io.BytesIO()
ds_chunk.to_parquet(ds_parquet_chunk)
if api:
print(f"Uploading '{chunk_name}' to '{args.out_repo_id}'...")
api.upload_file(
path_or_fileobj=ds_parquet_chunk,
path_in_repo=f"data/{chunk_name}",
repo_id=args.out_repo_id,
repo_type="dataset",
)
print(f"Done saving/uploading '{chunk_name}'")
31 changes: 23 additions & 8 deletions scripts/train_tokenizer.py
@@ -3,7 +3,6 @@

from datasets import Dataset, Features, Value, load_dataset
from tokenizers import ByteLevelBPETokenizer # type: ignore
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizerFast


@@ -15,14 +14,15 @@ def train_byte_level_bpe(
tokenizer.train_from_iterator(
text_generator,
vocab_size=vocab_size,
special_tokens=["<bos>", "<eos>"],
special_tokens=["<bos>", "<eos>", "<pad>"],
show_progress=True,
length=len(dataset),
)
return PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
bos_token="<bos>",
eos_token="<eos>",
pad_token="<pad>",
)


@@ -57,14 +57,22 @@ def train_byte_level_bpe(
required=True,
help="Vocabulary size of the tokenizer",
)
parser.add_argument(
"--out-dir",
type=str,
required=False,
help="Local directory to save the resulting tokenizer",
)
parser.add_argument(
"--out-repo-id",
"-o",
type=str,
required=True,
help="Where to push the resulting tokenizer",
required=False,
help="HF repo id to upload the resulting tokenizer",
)
args = parser.parse_args()
assert (
args.out_repo_id or args.out_dir
), "You need to provide out_repo_id or out_dir"

print(f"Loading dataset '{args.in_repo_id}'...")
in_dataset_split = load_dataset(
@@ -78,6 +86,13 @@ def train_byte_level_bpe(
feature=args.feature,
vocab_size=args.vocab_size,
)
tokenizer.push_to_hub(
repo_id=args.out_repo_id,
)
if args.out_dir:
print(f"Saving tokenizer to '{args.out_dir}' directory...")
tokenizer.save_pretrained(args.out_dir)
print("Done.")
if args.out_repo_id:
print(f"Pushing tokenizer to HF repo '{args.out_repo_id}'...")
tokenizer.push_to_hub(
repo_id=args.out_repo_id,
)
print("Done.")
53 changes: 12 additions & 41 deletions src/delphi/dataset/tokenization.py
@@ -1,10 +1,8 @@
import io
import itertools
from collections import deque
from collections.abc import Generator
from collections.abc import Iterator

from datasets import Dataset
from huggingface_hub import HfApi
from tqdm.auto import trange
from transformers import PreTrainedTokenizerBase


@@ -45,7 +43,7 @@ def extend_deque(
return doc_idx


def make_new_sample(deq: deque[int], context_size: int, bos_token_id: int) -> list[int]:
def make_new_sample(deq: deque[int], seq_len: int, bos_token_id: int) -> list[int]:
"""
Generates new sample for training by creating sequence of tokens
from the deque until the deque.
@@ -62,10 +60,10 @@ def make_new_sample(deq: deque[int], context_size: int, bos_token_id: int) -> li
list[int]: token sequence.
"""
sample = [bos_token_id]
# For the first (n-1) elements, pop from the left of the deque
# and add to the new sample, the n-th element will be retained
# For the first n-2 elements, pop from the left of the deque
# and add to the new sample, the (n-1)-th element will be retained
# in the deque for making the next sample.
for _ in range(context_size - 1):
for _ in range(seq_len - 2):
sample.append(deq.popleft())
sample.append(deq[0])
return sample
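A toy walk-through of the loop above under the new seq_len convention (the commit's "tokenize into seq_len instead of seq_len+1"); the token values and seq_len here are hypothetical:

from collections import deque

seq_len, bos_token_id = 6, 0
deq = deque([101, 102, 103, 104, 105, 106, 107, 108])
sample = [bos_token_id]
for _ in range(seq_len - 2):  # pop seq_len - 2 tokens from the left
    sample.append(deq.popleft())
sample.append(deq[0])         # peek without popping: this token overlaps into the next sample
print(sample)                 # [0, 101, 102, 103, 104, 105] -- exactly seq_len tokens
print(deq[0])                 # 105 stays in the deque to start the next sample's content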
@@ -76,7 +74,7 @@ def tokenize_dataset(
tokenizer: PreTrainedTokenizerBase,
seq_len: int,
batch_size: int,
) -> Generator[list[int], None, None]:
) -> Iterator[list[int]]:
"""
Tokenizes the input text documents using the provided tokenizer and
generates token sequences of the specified length.
@@ -100,45 +98,18 @@
# We discard the last chunk, so no processing on the remainder of the deque here


def tokenize_and_upload_split(
def get_tokenized_chunks(
dataset_split: Dataset,
split_name: str,
tokenizer: PreTrainedTokenizerBase,
seq_len: int,
batch_size: int,
chunk_size: int,
out_repo_id: str,
api: HfApi,
):
seq_gen = tokenize_dataset(
) -> Iterator[Dataset]:
seq_it = tokenize_dataset(
dataset_split,
tokenizer,
seq_len=seq_len,
batch_size=batch_size,
)
seq_it = iter(seq_gen)
print(f"Tokenizing {split_name=}...")
chunk_idx = 0
done = False
while not done:
tokens = []
print(f"Processing chunk {chunk_idx}...")
for _ in trange(chunk_size):
try:
tokens.append(next(seq_it))
except StopIteration:
done = True
break
ds_chunk = Dataset.from_dict({"tokens": tokens})
ds_parquet_chunk = io.BytesIO()
ds_chunk.to_parquet(ds_parquet_chunk)
chunk_name = f"{split_name}-{chunk_idx:05}.parquet"
print(f"Uploading {chunk_name}...")
api.upload_file(
path_or_fileobj=ds_parquet_chunk,
path_in_repo=f"data/{chunk_name}",
repo_id=out_repo_id,
repo_type="dataset",
)
chunk_idx += 1
print("Done.")
while tokens_chunk := tuple(itertools.islice(seq_it, chunk_size)):
yield Dataset.from_dict({"tokens": tokens_chunk})
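The walrus-plus-islice loop that replaces the manual chunk assembly is a standard pattern for chunking an iterator; a minimal standalone sketch, not the library code itself:

import itertools

def chunks(it, chunk_size):
    # repeatedly slice chunk_size items off the iterator; the final chunk may be
    # shorter, and an empty tuple terminates the loop
    it = iter(it)
    while chunk := tuple(itertools.islice(it, chunk_size)):
        yield chunk

print(list(chunks(range(7), 3)))  # [(0, 1, 2), (3, 4, 5), (6,)]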
14 changes: 7 additions & 7 deletions tests/dataset/test_tokeniation.py
@@ -73,7 +73,7 @@ def test_make_new_sample(tokenizer):


def test_tokenize_dataset(tokenizer):
CTX_SIZE = 10
SEQ_LEN = 11
BATCH_SIZE = 2

documents = [
@@ -86,11 +86,11 @@ def test_tokenize_dataset(tokenizer):
feature_name = get_random_feature_name()
dataset = Dataset.from_dict({feature_name: documents})
expected = [
[0, 431, 440, 260, 1, 46, 499, 1945, 368, 3443, 15],
[0, 15, 340, 576, 355, 337, 1887, 1, 431, 440, 260],
[0, 260, 399, 13, 314, 260, 560, 1005, 13, 402, 284],
[0, 284, 260, 2606, 1, 431, 440, 260, 399, 13, 402],
[0, 402, 284, 260, 1, 1370, 268, 415, 484, 412, 15],
[0, 432, 441, 261, 1, 47, 500, 1946, 369, 3444, 16],
[0, 16, 341, 577, 356, 338, 1888, 1, 432, 441, 261],
[0, 261, 400, 14, 315, 261, 561, 1006, 14, 403, 285],
[0, 285, 261, 2607, 1, 432, 441, 261, 400, 14, 403],
[0, 403, 285, 261, 1, 1371, 269, 416, 485, 413, 16],
]
actual = [x for x in tokenize_dataset(dataset, tokenizer, CTX_SIZE, BATCH_SIZE)]
actual = [x for x in tokenize_dataset(dataset, tokenizer, SEQ_LEN, BATCH_SIZE)]
assert actual == expected
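Note that the updated expected ids are shifted up by one relative to the old values (e.g. 431 -> 432) for everything except <bos> (0) and <eos> (1), consistent with the retrained tokenizer now reserving id 2 for <pad>; the samples remain 11 tokens long, since SEQ_LEN = 11 now denotes the full sample length rather than the context size plus one.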
