tokenize text stories and split into batches (#55)
* Add function to tokenize text stories and split into batches

* Split the tokenization function into two parts, fixing the while-loop issues

* Add docstrings to the functions

* Minor edits in the code, fix the test

* Use batch_encode() method to save time

* Add script to upload to delphi-suite/batched-tokenized-stories

* Remove the test file in tests/train to pass pytest

* Update function name

---------

authored-by: Siwei Li <[email protected]>
siwei-li authored Mar 30, 2024
1 parent 01eb277 commit 5b7ec89
Showing 3 changed files with 273 additions and 0 deletions.
78 changes: 78 additions & 0 deletions scripts/tokenize_dataset.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import argparse

from datasets import Dataset
from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset
from delphi.eval.utils import load_validation_dataset

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tokenize a text dataset and upload the result to Hugging Face."
    )

    parser.add_argument(
        "--input-dataset-name",
        type=str,
        help="Text dataset from huggingface to tokenize",
    )
    parser.add_argument(
        "--output-dataset-name",
        type=str,
        help="Name of the tokenized dataset to upload to huggingface",
    )
    parser.add_argument(
        "--tokenizer-name",
        type=str,
        help="Name of the tokenizer from huggingface",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="Hugging Face API token",
    )
    parser.add_argument(
        "--context-size",
        type=int,
        default=512,
        help="Context size of the tokenized dataset as input of the model",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Batch size of text inputs into the tokenizer",
    )
    parser.add_argument(
        "--column-name",
        type=str,
        help="Name of the column containing text documents in the input dataset",
    )
    args = parser.parse_args()

    input_dataset = load_validation_dataset(f"delphi-suite/{args.input_dataset_name}")
    tokenizer = AutoTokenizer.from_pretrained(f"delphi-suite/{args.tokenizer_name}")

    # Pick the text column: use --column-name if given, otherwise fall back to
    # the only column in the dataset.
    if args.column_name:
        text_docs = input_dataset[args.column_name]
    else:
        if len(input_dataset.column_names) > 1:
            raise ValueError(
                "The specified dataset has more than one column; "
                "use --column-name to select one"
            )
        text_docs = input_dataset[input_dataset.column_names[0]]

    output_dataset = Dataset.from_dict(
        {
            "tokens": tokenize_dataset(
                text_docs,
                tokenizer,
                context_size=args.context_size,
                batch_size=args.batch_size,
            )
        }
    )

    output_dataset.push_to_hub(
        repo_id=f"delphi-suite/{args.output_dataset_name}",
        private=False,
        token=args.token,
    )
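
For reference, a hypothetical invocation of this script might look like the following. The output dataset name matches the delphi-suite/batched-tokenized-stories repo mentioned in the commit message and the tokenizer name matches the one used in the tests below; the input dataset name, column name, and API token are placeholders, not values taken from this commit.

python scripts/tokenize_dataset.py \
    --input-dataset-name <input-dataset> \
    --output-dataset-name batched-tokenized-stories \
    --tokenizer-name stories-tokenizer \
    --token <hf-api-token> \
    --context-size 512 \
    --batch-size 50 \
    --column-name <text-column>

Note that the script prepends "delphi-suite/" to the dataset and tokenizer names itself, so only the bare repository names are passed.
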
107 changes: 107 additions & 0 deletions src/delphi/dataset/tokenization.py
@@ -0,0 +1,107 @@
from collections import deque
from typing import Optional

from transformers import PreTrainedTokenizerBase


def extend_deque(
    dq: deque[int],
    context_size: int,
    text_documents: list[str],
    doc_idx: int,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
) -> int:
    """
    Extends the deque with tokenized text documents until the deque grows large
    enough to reach the context size, or until all text documents are processed.

    Using a deque here saves memory compared to loading and tokenizing all
    documents at once.

    Args:
        dq: Deque to extend with tokenized tokens.
        context_size: Size of the context (input sequences).
        text_documents: List of (untokenized) text documents to be tokenized.
        doc_idx: Index of the current text story.
        tokenizer: Tokenizer to encode the text strings.
        batch_size: Number of text documents to tokenize at a time.

    Returns:
        int: Updated index in the text documents dataset.
    """
    while len(dq) < context_size and doc_idx < len(text_documents):
        text_doc = text_documents[doc_idx : doc_idx + batch_size]
        batch_input_ids = tokenizer(
            text_doc, return_attention_mask=False, add_special_tokens=False
        )["input_ids"]
        # Separate consecutive documents with an EOS token.
        for input_ids in batch_input_ids:
            dq.extend(input_ids + [tokenizer.eos_token_id])
        doc_idx += batch_size
    return doc_idx


def make_new_samples(
    dq: deque[int], context_size: int, bos_token_id: int
) -> list[list[int]]:
    """
    Generates new samples for training by creating sequences of tokens
    from the deque until the deque does not hold enough tokens to generate
    another sample.

    Note: the model is unable to use the last token in an input sequence,
    so we repeat this token at the start of the next input sequence.

    Args:
        dq: Deque containing tokenized tokens.
        context_size: Size of the context (input sequences).
        bos_token_id: bos_token_id of the tokenizer used.

    Returns:
        list[list[int]]: List of token sequences of the same length: a BOS token
            followed by context_size tokens (context_size + 1 tokens in total).
    """
    samples = []
    while len(dq) >= context_size:
        sample = [bos_token_id]

        # For the first (n-1) elements, pop from the left of the deque
        # and add to the new sample; the n-th element is retained
        # in the deque for making the next sample.
        for _ in range(context_size - 1):
            sample.append(dq.popleft())
        sample.append(dq[0])

        samples.append(sample)
    return samples


def tokenize_dataset(
    text_documents: list[str],
    tokenizer: PreTrainedTokenizerBase,
    context_size: int,
    batch_size: int,
) -> list[list[int]]:
    """
    Tokenizes the input text documents using the provided tokenizer and
    generates token sequences of the specified length.

    Args:
        text_documents: List of (untokenized) text documents to be tokenized.
        tokenizer: Tokenizer to encode the text strings.
        context_size: Size of the context (input sequences).
        batch_size: Number of text documents to tokenize at a time.

    Returns:
        list[list[int]]: List of token sequences, each a BOS token followed by
            context_size tokens (context_size + 1 tokens in total).
    """
    dq: deque[int] = deque()
    doc_idx = 0
    samples = []

    while doc_idx < len(text_documents):
        doc_idx = extend_deque(
            dq, context_size, text_documents, doc_idx, tokenizer, batch_size
        )
        samples.extend(make_new_samples(dq, context_size, tokenizer.bos_token_id))

    # We discard the last chunk, so no processing on the remainder of the deque here
    return samples
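
As a rough sketch of how these pieces fit together (not part of the commit; the tokenizer name is the one used in the tests below, while the example documents, context size, and batch size are made up):

from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset

tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
docs = [
    "Once there was a small dog.",
    "The dog liked to chase a ball in the park.",
]

# extend_deque fills a token buffer one batch of documents at a time (appending
# an EOS token after each document), and make_new_samples slices the buffer into
# BOS-prefixed sequences; tokenize_dataset alternates the two until the
# documents run out, discarding the final partial chunk.
samples = tokenize_dataset(docs, tokenizer, context_size=8, batch_size=2)

for sample in samples:
    assert sample[0] == tokenizer.bos_token_id
    assert len(sample) == 9  # context_size + 1: a BOS token plus context_size tokens

# Consecutive samples overlap by one token: the last token of each sample is
# repeated right after the BOS of the next one.
for prev, nxt in zip(samples, samples[1:]):
    assert nxt[1] == prev[-1]
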
88 changes: 88 additions & 0 deletions tests/dataset/test_tokenizer.py
@@ -0,0 +1,88 @@
import collections
import random

import pytest
from transformers import AutoTokenizer

from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset


@pytest.fixture
def tokenizer():
    return AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")


def test_extend_deque(tokenizer):
    CTX_SIZE = 10
    BATCH_SIZE = 2
    # generate 100 random stories
    text_stories = [
        " ".join(
            [
                tokenizer.decode(random.randint(3, tokenizer.vocab_size))
                for _ in range(random.randint(100, 800))
            ]
        )
        for _ in range(100)
    ]
    prompt_idx = 0
    dq = collections.deque()

    while prompt_idx < len(text_stories):
        prompt_idx = extend_deque(
            dq, CTX_SIZE, text_stories, prompt_idx, tokenizer, BATCH_SIZE
        )
        if prompt_idx < len(text_stories) - 1:
            # assert that the deque has grown large enough in each round
            assert len(dq) >= CTX_SIZE
        while len(dq) >= CTX_SIZE:
            for _ in range(CTX_SIZE - 1):
                dq.popleft()


def test_make_new_sample(tokenizer):
    for _ in range(100):
        total_tokens = random.randint(100, 1000)
        context_size = random.randint(5, total_tokens // 2)
        dq = collections.deque(random.choices(range(3, 1000), k=total_tokens))
        samples = make_new_samples(dq, context_size, tokenizer.bos_token_id)
        tokens_cnt = 0
        for i, sample in enumerate(samples):
            assert sample[0] == tokenizer.bos_token_id
            if i > 0:
                # assert that there is an overlap of the last token in the previous sample
                # and the first token in its following sample
                assert sample[1] == samples[i - 1][-1]
            tokens_cnt += len(sample)

        # We discard the last chunk so the following lines are only for testing
        tokens_cnt += 1 + len(dq)  # the last batch with BOS in the beginning
        assert tokens_cnt == total_tokens + (
            2 * len(samples) + 1
        )  # BOS for each batch + overlapping of the last tokens in the batches
        assert len(dq) > 0  # always leaving at least one element in the deque


def test_tokenize_dataset(tokenizer):
    CTX_SIZE = 10
    BATCH_SIZE = 2

    text_stories = [
        "Once upon a",
        "Mother woke up alert. She put on her coat",
        "Once upon a time, in a small town, there was a weird",
        "Once upon a time, there was a",
        "Sara and Tom are friends. They like to play in the park.",
    ]
    correct_batches = [
        [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037],
        [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261],
        [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286],
        [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406],
        [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037],
        [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
    ]
    assert (
        tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
        == correct_batches
    )
