tokenize text stories and split into batches (#55)
* Add function to tokenize text stories and split into batches
* Split the tokenization function into two parts, fixing the while-loop issues
* Add docstrings to the functions
* Minor edits in the code, fix the test
* Uses batch_encode() method to save time
* Add script to upload to delphi-suite/batched-tokenized-stories
* Remove the test file in tests/train to pass pytest
* Update function name

---------
authored-by: Siwei Li <[email protected]>
Showing 3 changed files with 273 additions and 0 deletions.
New file (78 lines): script that tokenizes a text dataset and uploads the result to the Hugging Face Hub.
#!/usr/bin/env python3

import argparse

from datasets import Dataset
from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset
from delphi.eval.utils import load_validation_dataset

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tokenize a text dataset and upload the result to Hugging Face"
    )

    parser.add_argument(
        "--input-dataset-name",
        type=str,
        help="Text dataset from huggingface to tokenize",
    )
    parser.add_argument(
        "--output-dataset-name",
        type=str,
        help="Name of the tokenized dataset to upload to huggingface",
    )
    parser.add_argument(
        "--tokenizer-name",
        type=str,
        help="Name of the tokenizer from huggingface",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="Hugging Face API token",
    )
    parser.add_argument(
        "--context-size",
        type=int,
        default=512,
        help="Context size of the tokenized dataset as input of the model",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Batch size of text inputs into the tokenizer",
    )
    parser.add_argument(
        "--column-name",
        type=str,
        help="Name of the column containing text documents in the input dataset",
    )
    args = parser.parse_args()

    input_dataset = load_validation_dataset(f"delphi-suite/{args.input_dataset_name}")
    tokenizer = AutoTokenizer.from_pretrained(f"delphi-suite/{args.tokenizer_name}")

    if args.column_name:
        text_docs = input_dataset[args.column_name]
    else:
        if len(input_dataset.column_names) > 1:
            raise ValueError("There is more than one column in the specified dataset")
        text_docs = input_dataset[input_dataset.column_names[0]]

    output_dataset = Dataset.from_dict(
        {
            "tokens": tokenize_dataset(
                text_docs,
                tokenizer,
                context_size=args.context_size,
                batch_size=args.batch_size,
            )
        }
    )

    output_dataset.push_to_hub(
        repo_id=f"delphi-suite/{args.output_dataset_name}",
        private=False,
        token=args.token,
    )
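For a quick sanity check after the upload, the pushed dataset can be loaded back with datasets.load_dataset. This is an editor's sketch, not part of the commit; it assumes the delphi-suite/batched-tokenized-stories name mentioned in the commit message and the default "train" split that push_to_hub creates.

# Sketch: load the uploaded dataset back and inspect its shape (not part of the commit).
from datasets import load_dataset

ds = load_dataset("delphi-suite/batched-tokenized-stories", split="train")
print(ds.column_names)       # expected: ['tokens']
print(len(ds[0]["tokens"]))  # each row holds context_size + 1 token ids (BOS included)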
New file (107 lines): the delphi.dataset.tokenization module.
from collections import deque
from typing import Optional

from transformers import PreTrainedTokenizerBase


def extend_deque(
    dq: deque[int],
    context_size: int,
    text_documents: list[str],
    doc_idx: int,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
) -> int:
    """
    Extends the deque with tokenized text documents until the deque grows large
    enough to reach the context size, or until all text documents are processed.

    The usage of a deque here aims to save memory, as opposed to loading all
    the documents and tokenizing them at once.

    Args:
        dq: Deque to extend with token ids.
        context_size: Size of the context (input sequences).
        text_documents: List of (untokenized) text documents to be tokenized.
        doc_idx: Index of the current text document.
        tokenizer: Tokenizer to encode the text strings.
        batch_size: Number of documents to tokenize in one batch.

    Returns:
        int: Updated index in the text documents dataset.
    """
    while len(dq) < context_size and doc_idx < len(text_documents):
        text_doc = text_documents[doc_idx : doc_idx + batch_size]
        batch_input_ids = tokenizer(
            text_doc, return_attention_mask=False, add_special_tokens=False
        )["input_ids"]
        for input_ids in batch_input_ids:
            dq.extend(input_ids + [tokenizer.eos_token_id])
        doc_idx += batch_size
    return doc_idx
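A minimal usage sketch (an editor's addition, not part of the commit) may help here; it assumes the delphi-suite/stories-tokenizer that the tests below rely on, and network access to download it:

# Sketch: fill a deque from a small batch of documents.
from collections import deque
from transformers import AutoTokenizer
from delphi.dataset.tokenization import extend_deque

tok = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
docs = ["Once upon a time", "there was a cat", "the end"]
dq: deque[int] = deque()

next_idx = extend_deque(dq, context_size=4, text_documents=docs,
                        doc_idx=0, tokenizer=tok, batch_size=2)
# The first batch of two documents already yields at least 4 tokens, so the loop
# stops after one pass: next_idx == 2, and dq holds the tokens of the first two
# documents, each followed by tok.eos_token_id.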

def make_new_samples(
    dq: deque[int], context_size: int, bos_token_id: int
) -> list[list[int]]:
    """
    Generates new samples for training by creating sequences of tokens
    from the deque until the deque does not hold enough tokens to generate
    another sample.

    Note: the model is unable to use the last token in an input sequence,
    so we repeat this token in the next input sequence.

    Args:
        dq: Deque containing token ids.
        context_size: Size of the context (input sequences).
        bos_token_id: bos_token_id of the tokenizer used.

    Returns:
        list[list[int]]: List of token sequences, all of the same length:
            context_size + 1 ids (a BOS token followed by context_size tokens
            from the deque).
    """
    samples = []
    while len(dq) >= context_size:
        sample = [bos_token_id]

        # For the first (n-1) elements, pop from the left of the deque
        # and add them to the new sample; the n-th element is retained
        # in the deque for making the next sample.
        for _ in range(context_size - 1):
            sample.append(dq.popleft())
        sample.append(dq[0])

        samples.append(sample)
    return samples
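A small worked example (an editor's addition, not part of the commit) makes the one-token overlap concrete; it uses plain integers in place of real token ids:

# Sketch: two samples from a deque of nine "tokens", with context_size=4 and BOS id 1.
from collections import deque
from delphi.dataset.tokenization import make_new_samples

dq = deque([10, 11, 12, 13, 14, 15, 16, 17, 18])
samples = make_new_samples(dq, context_size=4, bos_token_id=1)
# samples == [[1, 10, 11, 12, 13], [1, 13, 14, 15, 16]]
# Token 13 ends the first sample and reappears right after the BOS of the second.
# dq == deque([16, 17, 18]): fewer than context_size tokens remain, so no third sample.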

def tokenize_dataset(
    text_documents: list[str],
    tokenizer: PreTrainedTokenizerBase,
    context_size: int,
    batch_size: int,
) -> list[list[int]]:
    """
    Tokenizes the input text documents using the provided tokenizer and
    generates token sequences of the specified length.

    Args:
        text_documents: List of (untokenized) text documents to be tokenized.
        tokenizer: Tokenizer to encode the text strings.
        context_size: Size of the context (input sequences).
        batch_size: Number of documents to tokenize in one batch.

    Returns:
        list[list[int]]: List of token sequences, each of length
            context_size + 1 (a BOS token followed by context_size token ids).
    """
    dq = deque()
    doc_idx = 0
    samples = []

    while doc_idx < len(text_documents):
        doc_idx = extend_deque(
            dq, context_size, text_documents, doc_idx, tokenizer, batch_size
        )
        samples.extend(make_new_samples(dq, context_size, tokenizer.bos_token_id))

    # We discard the last chunk, so no processing on the remainder of the deque here
    return samples
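Putting the pieces together, the sketch below (an editor's addition, not part of the commit, again assuming the delphi-suite/stories-tokenizer) shows the shape of tokenize_dataset's output: every sample is context_size + 1 ids long and starts with the BOS token, and any tail of the token stream that cannot fill one more full sample is dropped.

# Sketch: end-to-end call and a check of the output shape.
from transformers import AutoTokenizer
from delphi.dataset.tokenization import tokenize_dataset

tok = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
stories = [
    "Once upon a time there was a tiny robot.",
    "Sara and Tom are friends. They like to play in the park.",
]
samples = tokenize_dataset(stories, tok, context_size=10, batch_size=2)

assert all(len(s) == 11 for s in samples)              # BOS + 10 token ids per sample
assert all(s[0] == tok.bos_token_id for s in samples)  # every sample starts with BOS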
New file (88 lines): tests for the tokenization functions.
import collections
import random

import pytest
from transformers import AutoTokenizer

from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset


@pytest.fixture
def tokenizer():
    return AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")


def test_extend_deque(tokenizer):
    CTX_SIZE = 10
    BATCH_SIZE = 2
    # generate 100 random stories
    text_stories = [
        " ".join(
            [
                tokenizer.decode(random.randint(3, tokenizer.vocab_size))
                for _ in range(random.randint(100, 800))
            ]
        )
        for _ in range(100)
    ]
    prompt_idx = 0
    dq = collections.deque()

    while prompt_idx < len(text_stories):
        prompt_idx = extend_deque(
            dq, CTX_SIZE, text_stories, prompt_idx, tokenizer, BATCH_SIZE
        )
        if prompt_idx < len(text_stories) - 1:
            # assert that the deque has grown large enough in each round
            assert len(dq) >= CTX_SIZE
        # drain the deque the way make_new_samples would (CTX_SIZE - 1 tokens per sample)
        while len(dq) >= CTX_SIZE:
            for _ in range(CTX_SIZE - 1):
                dq.popleft()


def test_make_new_sample(tokenizer):
    for _ in range(100):
        total_tokens = random.randint(100, 1000)
        context_size = random.randint(5, total_tokens // 2)
        dq = collections.deque(random.choices(range(3, 1000), k=total_tokens))
        samples = make_new_samples(dq, context_size, tokenizer.bos_token_id)
        tokens_cnt = 0
        for i, sample in enumerate(samples):
            assert sample[0] == tokenizer.bos_token_id
            if i > 0:
                # assert that the last token of the previous sample is repeated as
                # the second token (right after the BOS) of the current sample
                assert sample[1] == samples[i - 1][-1]
            tokens_cnt += len(sample)

        # We discard the last chunk, so the following lines are only for testing
        tokens_cnt += 1 + len(dq)  # count the leftover tokens as one last batch with a BOS
        assert tokens_cnt == total_tokens + (
            2 * len(samples) + 1
        )  # one BOS per batch + the overlapping last token of each batch
        assert len(dq) > 0  # always leaving at least one element in the deque


def test_tokenize_dataset(tokenizer):
    CTX_SIZE = 10
    BATCH_SIZE = 2

    text_stories = [
        "Once upon a",
        "Mother woke up alert. She put on her coat",
        "Once upon a time, in a small town, there was a weird",
        "Once upon a time, there was a",
        "Sara and Tom are friends. They like to play in the park.",
    ]
    correct_batches = [
        [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037],
        [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261],
        [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286],
        [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406],
        [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037],
        [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
    ]
    assert (
        tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
        == correct_batches
    )