
Commit
fix normalization order in find_missing_characters()
mshannon-sil committed Apr 5, 2024
1 parent bda3b54 commit 38cdb25
Showing 2 changed files with 12 additions and 8 deletions.
File 1 of 2:

@@ -172,12 +172,13 @@ def train(
 def find_missing_characters(tokenizer: Any, train_dataset: Dataset, lang_codes: List[str]) -> List[str]:
     vocab = tokenizer.get_vocab().keys()
     charset = set()
-    for lang_code in lang_codes:
-        for ex in train_dataset["translation"]:
-            charset = charset | set(ex[lang_code])
-    if isinstance(tokenizer, (NllbTokenizerFast)):
-        charset = {self._mpn.normalize(char) for char in charset}
-        charset = {tokenizer.backend_tokenizer.normalizer.normalize_str(char) for char in charset}
+    for ex in train_dataset["translation"]:
+        for lang_code in lang_codes:
+            ex_text = ex[lang_code]
+            if isinstance(tokenizer, (NllbTokenizerFast)):
+                ex_text = self._mpn.normalize(ex_text)
+                ex_text = tokenizer.backend_tokenizer.normalizer.normalize_str(ex_text)
+            charset = charset | set(ex_text)
     charset = set(filter(None, {char.strip() for char in charset}))
     missing_characters = sorted(list(charset - vocab))
     return missing_characters
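Why the order matters: the old code collected characters from the raw text and only then normalized them one character at a time, but Unicode normalization can merge a base letter and a combining mark into a single precomposed codepoint, which a per-character pass can never produce; the fix normalizes each full text first and then splits it into characters. Below is a minimal sketch of the difference, using unicodedata.normalize("NFC", ...) as a stand-in for the tokenizer's backend normalizer (an assumption for illustration; the real code applies a Moses punctuation normalizer and the NLLB tokenizer's own normalizer):

import unicodedata

text = "L\u0331"  # decomposed form of "Ḻ": "L" + COMBINING MACRON BELOW

# Old order: build the character set first, then normalize per character.
# Normalizing a lone "L" or a lone combining mark is a no-op, so the
# precomposed "Ḻ" (U+1E3A) that normalization would produce never appears.
old = {unicodedata.normalize("NFC", ch) for ch in set(text)}
assert "\u1E3A" not in old

# New order: normalize the full text first, then split into characters.
new = set(unicodedata.normalize("NFC", text))
assert "\u1E3A" in new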
File 2 of 2:

@@ -95,8 +95,8 @@ def test_update_tokenizer_missing_char() -> None:
     MemoryText(
         "text1",
         [
-            _row(1, "Ḻ ḻ Ṉ"),
-            _row(2, "d e f"),
+            _row(1, "Ḻ ḻ Ṉ"),
+            _row(2, "d e f g"),
         ],
     )
 ]
@@ -137,6 +137,7 @@ def test_update_tokenizer_missing_char() -> None:
     finetuned_result_nochar = finetuned_engine_nochar._tokenizer.encode(
         "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
     )
+    finetuned_result_nochar_composite = finetuned_engine_nochar._tokenizer.encode("Ḏ is a composite character")
 
     trainer_char = HuggingFaceNmtModelTrainer(
         "hf-internal-testing/tiny-random-nllb",
@@ -156,6 +157,7 @@ def test_update_tokenizer_missing_char() -> None:
     finetuned_result_char = finetuned_engine_char._tokenizer.encode(
         "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
     )
+    finetuned_result_char_composite = finetuned_engine_char._tokenizer.encode("Ḏ is a composite character")
 
     assert isinstance(finetuned_engine_nochar._tokenizer, PreTrainedTokenizerFast) and isinstance(
         finetuned_engine_char._tokenizer, PreTrainedTokenizerFast
@@ -171,6 +173,7 @@ def test_update_tokenizer_missing_char() -> None:
     assert normalized_result_nochar2 != normalized_result_char2
 
     assert finetuned_result_nochar != finetuned_result_char
+    assert finetuned_result_nochar_composite != finetuned_result_char_composite
 
 
 def test_update_tokenizer_missing_char_skip() -> None:
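The new test lines feed a "composite character" through both fine-tuned engines and assert that the tokenizer whose vocabulary was updated encodes it differently from the one that was not. A small illustration, assuming the composite character in question is the precomposed codepoint Ḏ (U+1E0E), whose canonical decomposition is D + U+0331:

import unicodedata

composed = "\u1E0E"  # Ḏ  LATIN CAPITAL LETTER D WITH LINE BELOW
decomposed = unicodedata.normalize("NFD", composed)

assert decomposed == "D\u0331"  # base letter + combining mark
assert unicodedata.normalize("NFC", decomposed) == composed
# A per-character pass cannot recover the composed form:
assert {unicodedata.normalize("NFC", ch) for ch in decomposed} == {"D", "\u0331"}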
