fix normalization order in find_missing_characters() #105

Merged · 2 commits · Apr 16, 2024
@@ -172,12 +172,14 @@ def train(
         def find_missing_characters(tokenizer: Any, train_dataset: Dataset, lang_codes: List[str]) -> List[str]:
             vocab = tokenizer.get_vocab().keys()
             charset = set()
-            for lang_code in lang_codes:
-                for ex in train_dataset["translation"]:
-                    charset = charset | set(ex[lang_code])
-            if isinstance(tokenizer, (NllbTokenizerFast)):
-                charset = {self._mpn.normalize(char) for char in charset}
-            charset = {tokenizer.backend_tokenizer.normalizer.normalize_str(char) for char in charset}
+            mpn_normalize = True if isinstance(tokenizer, (NllbTokenizerFast)) else False
+            for ex in train_dataset["translation"]:
+                for lang_code in lang_codes:
+                    ex_text = ex[lang_code]
+                    if mpn_normalize:
+                        ex_text = self._mpn.normalize(ex_text)
+                    ex_text = tokenizer.backend_tokenizer.normalizer.normalize_str(ex_text)
+                    charset = charset | set(ex_text)
             charset = set(filter(None, {char.strip() for char in charset}))
             missing_characters = sorted(list(charset - vocab))
             return missing_characters
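
Why the order matters: the old code built the character set from raw text and then normalized each character in isolation. A composite character such as "Ḏ" (U+1E0E) can normalize to a multi-codepoint sequence ("D" + U+0331 COMBINING MACRON BELOW), which then sits in the charset as a single two-codepoint string and is compared against the vocabulary as one unit. The new code normalizes each example's full text first and only then splits it into characters, so every codepoint of the decomposition is checked individually. A minimal sketch of the difference, using unicodedata.normalize("NFD", ...) as a stand-in for the tokenizer's backend normalizer (NLLB's precompiled normalizer is not literally NFD, so this is illustrative only):

import unicodedata

text = "Ḏ is a composite character"  # "Ḏ" (U+1E0E) decomposes to "D" + U+0331

# Old order: split into characters first, then normalize each element.
# "Ḏ" becomes the two-codepoint string "D\u0331" but stays one set element.
per_char = {unicodedata.normalize("NFD", ch) for ch in set(text)}

# New order: normalize the whole text first, then split into characters.
# The combining mark U+0331 now appears as an element of its own.
whole_first = set(unicodedata.normalize("NFD", text))

print(per_char - whole_first)   # {'Ḏ'}: the glued two-codepoint string
print("\u0331" in whole_first)  # True; with the old order this was False
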
@@ -95,8 +95,8 @@ def test_update_tokenizer_missing_char() -> None:
             MemoryText(
                 "text1",
                 [
-                    _row(1, "Ḻ ḻ Ṉ"),
-                    _row(2, "d e f"),
+                    _row(1, "Ḻ ḻ Ṉ"),
+                    _row(2, "d e f g"),
                 ],
             )
         ]
@@ -137,6 +137,7 @@ def test_update_tokenizer_missing_char() -> None:
     finetuned_result_nochar = finetuned_engine_nochar._tokenizer.encode(
         "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
     )
+    finetuned_result_nochar_composite = finetuned_engine_nochar._tokenizer.encode("Ḏ is a composite character")

     trainer_char = HuggingFaceNmtModelTrainer(
         "hf-internal-testing/tiny-random-nllb",
@@ -156,6 +157,7 @@ def test_update_tokenizer_missing_char() -> None:
     finetuned_result_char = finetuned_engine_char._tokenizer.encode(
         "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
     )
+    finetuned_result_char_composite = finetuned_engine_char._tokenizer.encode("Ḏ is a composite character")

     assert isinstance(finetuned_engine_nochar._tokenizer, PreTrainedTokenizerFast) and isinstance(
         finetuned_engine_char._tokenizer, PreTrainedTokenizerFast
@@ -171,6 +173,7 @@ def test_update_tokenizer_missing_char() -> None:
     assert normalized_result_nochar2 != normalized_result_char2

     assert finetuned_result_nochar != finetuned_result_char
+    assert finetuned_result_nochar_composite != finetuned_result_char_composite


 def test_update_tokenizer_missing_char_skip() -> None:
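
Taken together, the new assertions pin down the fix: both engines encode the same composite-character input, but only the trainer that adds missing characters to the tokenizer picks up the decomposed codepoints of "Ḏ", so the two encodings must differ. Before the reordering, the combining mark was never surfaced individually by find_missing_characters, which is presumably why this regression test was added alongside the fix.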