arg names & cosmetics
jettjaniak committed Apr 11, 2024
1 parent 89910dd commit 495328e
Showing 2 changed files with 33 additions and 35 deletions.
57 changes: 28 additions & 29 deletions scripts/tokenize_dataset.py
@@ -7,28 +7,34 @@
 from transformers import AutoTokenizer
 
 from delphi.dataset.tokenization import tokenize_dataset
-from delphi.eval.utils import load_validation_dataset
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="")
 
     parser.add_argument(
-        "--input-dataset-name",
+        "-i",
+        "--input-dataset",
         type=str,
         help="Text dataset from huggingface to tokenize",
     )
     parser.add_argument(
-        "--output-dataset-name",
+        "--column-name",
+        type=str,
+        help="Name of the column containing text documents in the input dataset",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-dataset",
         type=str,
         help="Name of the tokenized dataset to upload to huggingface",
     )
     parser.add_argument(
-        "--tokenizer-name",
+        "--tokenizer",
         type=str,
         help="Name of the tokenizer from huggingface",
     )
     parser.add_argument(
-        "--token",
+        "--hf-token",
         type=str,
         help="Hugging Face API token",
     )
@@ -44,36 +50,29 @@
         default=50,
         help="Batch size of text inputs into the tokenizer",
     )
-    parser.add_argument(
-        "--column-name",
-        type=str,
-        help="Name of the column containing text documents in the input dataset",
-    )
     args = parser.parse_args()
 
-    input_dataset = load_dataset(args.input_dataset_name)
-    input_dataset = cast(Dataset, input_dataset)
-    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
+    print(f"Loading dataset '{args.input_dataset}'...")
+    input_dataset = load_dataset(args.input_dataset)
+    input_dataset = cast(DatasetDict, input_dataset)
+    print(f"Loading tokenizer '{args.tokenizer}'...")
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+    assert tokenizer.bos_token_id is not None, "Tokenizer must have a bos_token_id"
+    assert tokenizer.eos_token_id is not None, "Tokenizer must have a eos_token_id"
 
     splits = list(input_dataset.keys())
     tokenized_datasets = {}  # dict that will hold tokenized vers. of each dataset split
+    print(f"{splits=}")
 
     for i, split in enumerate(splits):
-        print(f"Tokenizing {split = }", flush=True)
         text_docs = input_dataset[split]
-
-        if args.column_name:
-            text_docs = text_docs[args.column_name]
-        else:
-            if len(input_dataset.column_names) > 1:
-                raise ValueError(
-                    "There is more than one column in the specified dataset"
-                )
-            print(f"Using column {input_dataset.column_names[0]}")
-            text_docs = input_dataset[input_dataset.column_names[0]]
-
+        assert (
+            args.column_name or len(text_docs.column_names) == 1
+        ), "--column-name required when dataset has multiple columns"
+        column_name = args.column_name or text_docs.column_names[0]
+        print(f"Tokenizing {split=} {column_name=}")
         tokenized_dataset = tokenize_dataset(
-            text_docs,
+            text_docs[column_name],
             tokenizer,
             context_size=args.context_size,
             batch_size=args.batch_size,
@@ -84,12 +83,12 @@
     # Create a new dataset with the same structure (splits) as the original dataset, but with tokenized data
     output_dataset = DatasetDict(tokenized_datasets)
 
-    print("Tokenizaton completed. Uploading dataset to Huggingface.", flush=True)
+    print("Tokenizaton completed. Uploading dataset to Huggingface.")
 
     output_dataset.push_to_hub(
-        repo_id=args.output_dataset_name,
+        repo_id=args.output_dataset,
         private=False,
-        token=args.token,
+        token=args.hf_token,
     )
 
     print("Done.", flush=True)
11 changes: 5 additions & 6 deletions src/delphi/dataset/tokenization.py
@@ -1,5 +1,4 @@
 from collections import deque
-from typing import Optional
 
 from tqdm.auto import tqdm
 from transformers import PreTrainedTokenizerBase
@@ -34,7 +33,7 @@ def extend_deque(
         batch_input_ids = tokenizer(
             text_doc, return_attention_mask=False, add_special_tokens=False
         )["input_ids"]
-        for input_ids in batch_input_ids:
+        for input_ids in batch_input_ids:  # type: ignore
             dq.extend(input_ids + [tokenizer.eos_token_id])
         doc_idx += batch_size
     return doc_idx
@@ -93,22 +92,22 @@ def tokenize_dataset(
     Returns:
         list[list[int]]: List of token sequences of length equal to context_size.
     """
-
+    assert tokenizer.bos_token_id is not None
     dq = deque()
     doc_idx = 0
     samples = []
 
     pbar = tqdm(total=len(text_documents), desc="Tokenizing text documents", leave=True)
-    old_idx = 0
+    prev_doc_idx = 0
     # iterate through the text documents and tokenize them
     while doc_idx < len(text_documents):
         doc_idx = extend_deque(
             dq, context_size, text_documents, doc_idx, tokenizer, batch_size
         )
         samples.extend(make_new_samples(dq, context_size, tokenizer.bos_token_id))
         # update the tqdm bar
-        pbar.update(doc_idx - old_idx)
-        old_idx = doc_idx
+        pbar.update(doc_idx - prev_doc_idx)
+        prev_doc_idx = doc_idx
 
     pbar.close()
 
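
For reference, a rough sketch (an assumption, not something this commit adds) of driving the updated tokenize_dataset directly from Python, using the call signature visible in the diff; the "gpt2" tokenizer and the toy documents are placeholders:

from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder; needs bos/eos token ids
docs = [
    "Once upon a time there was a tiny robot.",
    "The robot learned to count tokens.",
]  # placeholder documents, standing in for text_docs[column_name]

# Documents are tokenized in batches, separated by eos, and chunked into
# samples of exactly context_size tokens (per the docstring above); the
# bos id passed to make_new_samples suggests each sample starts with bos.
samples = tokenize_dataset(docs, tokenizer, context_size=16, batch_size=2)
assert all(len(sample) == 16 for sample in samples)
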
