-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dataset tokenization script improvements (#106)
* Update tokenize_dataset.py to use load_dataset instead of load_validation_dataset
* include tqdm
* add split handling functionality
* beauty fix
* arg names & cosmetics
* added scripts/demo_upload_in_chunks.py
* Update tokenization script to upload to HF in chunks
* Uncomment the repo creation
* refactor, cosmetics
* memory usage FIXMEs
* fix memory usage issues
* fix test_tokenization

---------

Co-authored-by: Jett <[email protected]>
Co-authored-by: Siwei Li <[email protected]>
- Loading branch information
1 parent
51a8e57
commit ad2936f
Showing 3 changed files with 182 additions and 116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,100 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import argparse | ||
|
||
from datasets import Dataset | ||
from datasets import Dataset, Features, Value, load_dataset | ||
from huggingface_hub import HfApi | ||
from transformers import AutoTokenizer | ||
|
||
from delphi.dataset.tokenization import tokenize_dataset | ||
from delphi.eval.utils import load_validation_dataset | ||
from delphi.dataset.tokenization import tokenize_and_upload_split | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="") | ||
parser = argparse.ArgumentParser(description="", allow_abbrev=False) | ||
|
||
parser.add_argument( | ||
"--input-dataset-name", | ||
"--in-repo-id", | ||
"-i", | ||
type=str, | ||
required=True, | ||
help="Text dataset from huggingface to tokenize", | ||
) | ||
parser.add_argument( | ||
"--output-dataset-name", | ||
"--feature", | ||
"-f", | ||
type=str, | ||
help="Name of the tokenized dataset to upload to huggingface", | ||
required=True, | ||
help="Name of the column containing text documents in the input dataset", | ||
) | ||
parser.add_argument( | ||
"--tokenizer-name", | ||
"--split", | ||
"-s", | ||
type=str, | ||
help="Name of the tokenizer from huggingface", | ||
required=True, | ||
help="Split of the dataset to be tokenized, supports slicing like 'train[:10%%]'", | ||
) | ||
parser.add_argument( | ||
"--token", | ||
"--out-repo-id", | ||
"-o", | ||
type=str, | ||
help="Hugging Face API token", | ||
required=True, | ||
help="Name of the tokenized dataset to upload to huggingface", | ||
) | ||
parser.add_argument( | ||
"--context-size", | ||
"--tokenizer", | ||
"-r", | ||
type=str, | ||
required=True, | ||
help="Name of the tokenizer from huggingface", | ||
) | ||
parser.add_argument( | ||
"--seq-len", | ||
"-l", | ||
type=int, | ||
default=512, | ||
required=True, | ||
help="Context size of the tokenized dataset as input of the model", | ||
) | ||
parser.add_argument( | ||
"--hf-token", | ||
"-t", | ||
type=str, | ||
help="Hugging Face API token", | ||
) | ||
parser.add_argument( | ||
"--batch-size", | ||
"-b", | ||
type=int, | ||
default=50, | ||
help="Batch size of text inputs into the tokenizer", | ||
help="Size of input into batched tokenization", | ||
) | ||
parser.add_argument( | ||
"--column-name", | ||
type=str, | ||
help="Name of the column containing text documents in the input dataset", | ||
"--chunk-size", | ||
"-c", | ||
type=int, | ||
default=200_000, | ||
help="Size of the parquet chunks uploaded to HuggingFace", | ||
) | ||
args = parser.parse_args() | ||
|
||
input_dataset = load_validation_dataset(args.input_dataset_name) | ||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name) | ||
|
||
if args.column_name: | ||
text_docs = input_dataset[args.column_name] | ||
else: | ||
if len(input_dataset.column_names) > 1: | ||
raise ValueError("There is more than one column in the specified dataset") | ||
text_docs = input_dataset[input_dataset.column_names[0]] | ||
|
||
tokenized_dataset = tokenize_dataset( | ||
text_docs, | ||
tokenizer, | ||
context_size=args.context_size, | ||
batch_size=args.batch_size, | ||
) | ||
output_dataset = Dataset.from_dict( | ||
{ | ||
"tokens": tokenized_dataset, | ||
} | ||
print(f"Loading dataset '{args.in_repo_id}'...") | ||
in_dataset_split = load_dataset( | ||
args.in_repo_id, | ||
split=args.split, | ||
features=Features({args.feature: Value("string")}), | ||
) | ||
assert isinstance(in_dataset_split, Dataset) | ||
print(f"Loading tokenizer '{args.tokenizer}'...") | ||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) | ||
assert tokenizer.bos_token_id is not None, "Tokenizer must have a bos_token_id" | ||
assert tokenizer.eos_token_id is not None, "Tokenizer must have a eos_token_id" | ||
|
||
output_dataset.push_to_hub( | ||
repo_id=args.output_dataset_name, | ||
private=False, | ||
token=args.token, | ||
api = HfApi(token=args.hf_token) | ||
api.create_repo(repo_id=args.out_repo_id, repo_type="dataset", exist_ok=True) | ||
tokenize_and_upload_split( | ||
dataset_split=in_dataset_split, | ||
split_name=args.split.split("[")[0], | ||
tokenizer=tokenizer, | ||
seq_len=args.seq_len, | ||
batch_size=args.batch_size, | ||
chunk_size=args.chunk_size, | ||
out_repo_id=args.out_repo_id, | ||
api=api, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.