Commit: bug fixes and test script

jannik-brinkmann committed Apr 9, 2024
1 parent c39a156 commit 6385892
Showing 3 changed files with 33 additions and 22 deletions.
30 changes: 13 additions & 17 deletions scripts/train_tokenizer.py
@@ -8,15 +8,14 @@
 from transformers import LlamaTokenizerFast
 from tokenizers import SentencePieceBPETokenizer
 
-from delphi.train.tokenizer import train_vocab, get_tokenizer_model_path
+from delphi.train.tokenizer import train_vocab
 
 
 def main(
     vocab_size: int,
     dataset_name: str,
     column: str,
     train_size: float,
-    username: str,
     repo_id: str,
     token: str,
     seed: int,
@@ -33,8 +32,9 @@ def main(
     - repo_id: Hugging Face repository ID
     - token: Hugging Face API token
     """
+    print("repo id", repo_id)
     train_ds = load_dataset(dataset_name)["train"]
-    if train_size < 1.0:
+    if (isinstance(train_size, int) and train_size > 1) or (isinstance(train_size, float) and train_size < 1.0):
         train_ds = train_ds.train_test_split(
             train_size=train_size,
             seed=seed
@@ -85,10 +85,12 @@ def main(
     print("Converted tokenizer to huggingface tokenizer.")
 
     # push tokenizer to the hub
-    tokenizer.push_to_hub(
-        repo_id=repo_id,
-    )
-    print("Pushed tokenizer to huggingface hub.")
+    if repo_id:
+        tokenizer.push_to_hub(
+            repo_id=repo_id,
+            # token=token,
+        )
+        print("Pushed tokenizer to huggingface hub.")
 
 
 if __name__ == "__main__":
@@ -116,14 +118,9 @@ def main(
         default=1.0,
     )
     parser.add_argument(
-        "--username",
-        type=str,
-        help="Hugging Face API username",
-    )
-    parser.add_argument(
-        "--repo-name",
+        "--repo-id",
         type=str,
-        help="Hugging Face API username",
+        help="Hugging Face repository ID",
     )
     parser.add_argument(
         "--token",
@@ -134,19 +131,18 @@
         "--seed",
         type=int,
         help="Seed",
+        default=42
     )
     parser.add_argument(
         "--test-funct", action="store_true", help="Enable test function mode"
     )
 
     args = parser.parse_args()
 
     main(
         args.vocab_size,
         args.dataset_name,
-        args.train_size,
         args.column,
-        args.username,
+        args.train_size,
         args.repo_id,
         args.token,
         args.seed,
7 changes: 2 additions & 5 deletions src/delphi/train/tokenizer.py
@@ -9,16 +9,12 @@ def train_vocab(
     vocab_size: int,
     dataset: Dataset,
     column: str,
-    cache_dir: str = "cache"
 ) -> None:
     """
     Trains a custom SentencePiece tokenizer.
     """
     assert vocab_size > 0, "Vocab size must be positive"
 
-    if not os.path.exists(cache_dir):
-        os.makedirs(cache_dir)
-
     with tempfile.NamedTemporaryFile(mode='w+', suffix='.json') as tmpfile:
 
         # export text as a single text file
@@ -30,7 +26,8 @@ def train_vocab(
         print(f"Size is: {os.path.getsize(tmpfile.name) / 1024 / 1024:.2f} MB")
 
         # train the tokenizer
-        prefix = os.path.join(cache_dir, f"tok{vocab_size}")
+        prefix = f"tok{vocab_size}"
+
        spm.SentencePieceTrainer.train(
            input=tmpfile.name,
            model_prefix=prefix,
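
With cache_dir removed, the model prefix is now relative, so the trained artifacts land in the current working directory instead of cache/. As a quick sanity check after a vocab_size=4096 run (file names follow SentencePiece's standard <prefix>.model / <prefix>.vocab convention):

# SentencePiece writes <prefix>.model and <prefix>.vocab next to where
# the script was launched; with vocab_size=4096 the prefix is "tok4096".
ls -lh tok4096.model tok4096.vocab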
18 changes: 18 additions & 0 deletions tests/scripts/functional_test_train_tokenizer.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+# Test whether train_tokenizer.py works.
+
+VOCAB_SIZE=4096
+DATASET_NAME="delphi-suite/stories"
+COLUMN="story"  # Dataset column containing the training text
+
+# Train the tokenizer
+python3 scripts/train_tokenizer.py \
+    --vocab-size "$VOCAB_SIZE" \
+    --dataset-name "$DATASET_NAME" \
+    --column "$COLUMN"
+
+# Check if the local tokenizer model file exists
+TOKENIZER_MODEL_PATH="./tok${VOCAB_SIZE}.model"
+if test -f "$TOKENIZER_MODEL_PATH"; then
+    echo "Tokenizer trained."
+fi
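
The script above exercises only the local, no-push path: no --repo-id is passed, so the new if repo_id: guard skips the upload. A hypothetical invocation that would also hit the push-to-hub branch, with placeholder repository ID and token (not part of this commit; --train-size is assumed to follow the same hyphenated naming as the other flags):

# Sketch only: train on 90% of the data and push the tokenizer to the hub.
# "your-username/stories-tok4096" and $HF_TOKEN are placeholders.
python3 scripts/train_tokenizer.py \
    --vocab-size 4096 \
    --dataset-name "delphi-suite/stories" \
    --column "story" \
    --train-size 0.9 \
    --repo-id "your-username/stories-tok4096" \
    --token "$HF_TOKEN" \
    --seed 42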
