Commit c39a156
another use of tempfile (untested)
jannik-brinkmann committed Apr 5, 2024
1 parent 0fccd5a commit c39a156
Showing 1 changed file with 28 additions and 33 deletions.
61 changes: 28 additions & 33 deletions src/delphi/train/tokenizer.py
@@ -1,6 +1,7 @@
 import os
 
 import sentencepiece as spm
+import tempfile
 
 from datasets import Dataset
 

@@ -18,37 +19,31 @@ def train_vocab(
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
 
-    # export text as a single text file
-    text_file = os.path.join(cache_dir, "text.txt")
-    with open(text_file, 'w', encoding='utf-8') as file:
-        for item in dataset:
-            text = item[column]
-            text = text.strip()
-            file.write(text + '\n')
-    print(f"Size is: {os.path.getsize(text_file) / 1024 / 1024:.2f} MB")
-
-    # train the tokenizer
-    prefix = os.path.join(cache_dir, f"tok{vocab_size}")
-    spm.SentencePieceTrainer.train(
-        input=text_file,
-        model_prefix=prefix,
-        model_type="bpe",
-        vocab_size=vocab_size,
-        self_test_sample_size=0,
-        input_format="text",
-        character_coverage=1.0,
-        num_threads=os.cpu_count(),
-        split_digits=True,
-        allow_whitespace_only_pieces=True,
-        byte_fallback=True,
-        unk_surface=r" \342\201\207 ",
-        normalization_rule_name="identity"
-    )
-
-    # optional cleanup of the text file
-    dec = input(f"Delete the temporary file {text_file}? [y/N] ")
-    if dec.lower() == "y":
-        os.remove(text_file)
-        print(f"Deleted {text_file}")
+    with tempfile.NamedTemporaryFile(mode='w+', suffix='.json') as tmpfile:
+
+        # export text as a single text file
+        with open(tmpfile.name, 'w', encoding='utf-8') as file:
+            for item in dataset:
+                text = item[column]
+                text = text.strip()
+                file.write(text + '\n')
+        print(f"Size is: {os.path.getsize(tmpfile.name) / 1024 / 1024:.2f} MB")
+
+        # train the tokenizer
+        prefix = os.path.join(cache_dir, f"tok{vocab_size}")
+        spm.SentencePieceTrainer.train(
+            input=tmpfile.name,
+            model_prefix=prefix,
+            model_type="bpe",
+            vocab_size=vocab_size,
+            self_test_sample_size=0,
+            input_format="text",
+            character_coverage=1.0,
+            num_threads=os.cpu_count(),
+            split_digits=True,
+            allow_whitespace_only_pieces=True,
+            byte_fallback=True,
+            unk_surface=r" \342\201\207 ",
+            normalization_rule_name="identity"
+        )
+    print(f"Trained tokenizer is in {prefix}.model")

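For context, a hypothetical call to the changed function. The signature is truncated in the hunk header (def train_vocab(), and the parameter names below are inferred from the function body, so treat them as assumptions; the import path mirrors the file location src/delphi/train/tokenizer.py.

    from datasets import Dataset

    from delphi.train.tokenizer import train_vocab

    # Tiny illustrative dataset; a real corpus would be far larger.
    dataset = Dataset.from_dict({"story": ["Once upon a time...", "The end."]})
    train_vocab(dataset=dataset, column="story", cache_dir="./cache", vocab_size=512)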