
Commit

Checks for language_ids
stefanik12 committed Apr 16, 2024
1 parent a873644 commit 4c272d3
Showing 3 changed files with 23 additions and 10 deletions.
6 changes: 6 additions & 0 deletions adaptor/adapter.py
@@ -47,6 +47,12 @@ def __init__(self, lang_module: LangModule, schedule: Schedule, args: Adaptation

orig_callbacks = [] if "callbacks" not in kwargs else kwargs.pop("callbacks")

+ all_objectives_ids = list(map(str, self.schedule.objectives["train"].values()))
+ if len(set(all_objectives_ids)) < len(all_objectives_ids):
+     duplicates = [identifier for identifier in all_objectives_ids if all_objectives_ids.count(identifier) > 1]
+     raise ValueError("These objectives have identical identifiers: %s; This would cause "
+                      "incorrect persistence of checkpoints for your objectives." % set(duplicates))

super().__init__(model=lang_module,
args=args,
train_dataset=self.schedule.iterable_dataset(split="train"),
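
For reference, the new duplicate-identifier guard can be exercised on its own. A minimal sketch, using plain string identifiers in place of real Objective instances; check_unique_identifiers is a hypothetical helper mirroring the check added to Adapter.__init__, written with collections.Counter instead of list.count:

from collections import Counter
from typing import List

def check_unique_identifiers(all_objectives_ids: List[str]) -> None:
    # Mirror of the new guard: fail fast when two objectives stringify to the
    # same identifier, which would break per-objective checkpoint persistence.
    duplicates = {i for i, n in Counter(all_objectives_ids).items() if n > 1}
    if duplicates:
        raise ValueError("These objectives have identical identifiers: %s; This would cause "
                         "incorrect persistence of checkpoints for your objectives." % duplicates)

check_unique_identifiers(["MaskedLM-en", "Seq2Seq-en-de"])        # distinct ids: passes
try:
    check_unique_identifiers(["Seq2Seq-en-de", "Seq2Seq-en-de"])  # duplicate ids
except ValueError as e:
    print(e)
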
4 changes: 2 additions & 2 deletions adaptor/objectives/objective_base.py
@@ -58,7 +58,7 @@ def __init__(self,
max_samples_per_log: int = 1000,
max_samples_per_eval_log: int = 10000,
data_iteration_offset: int = 0,
- prefetch_in_parallel_thread: bool = True,
+ prefetch_in_parallel_thread: bool = False,
remember_last_input: Optional[bool] = False):
"""
Shared initialisation logic of every Objective.
@@ -165,7 +165,7 @@ def _compute_data_source_length(self, texts_or_path: Union[str, List[str]]) -> i
with io.TextIOWrapper(io.BufferedReader(gzip.open(texts_or_path))) as f:
return sum(1 for _ in f) # more efficient line count
else:
- with open(self.texts_path, "rb") as f:
+ with open(texts_or_path, "rb") as f:
return sum(1 for _ in f) # more efficient line count

elif isinstance(texts_or_path, list):
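
The second hunk fixes _compute_data_source_length to count lines of the path it was actually given (texts_or_path) rather than always reading self.texts_path. A standalone sketch of the same streaming line count; the .gz suffix check is an assumption, since the real branching condition lies outside this hunk:

import gzip
import io
from typing import List, Union

def data_source_length(texts_or_path: Union[str, List[str]]) -> int:
    # Count samples without loading the whole data source into memory.
    if isinstance(texts_or_path, str):
        if texts_or_path.endswith(".gz"):
            # stream-decompress, then count lines
            with io.TextIOWrapper(io.BufferedReader(gzip.open(texts_or_path))) as f:
                return sum(1 for _ in f)
        with open(texts_or_path, "rb") as f:
            return sum(1 for _ in f)
    return len(texts_or_path)
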
23 changes: 15 additions & 8 deletions adaptor/objectives/seq2seq.py
@@ -1,10 +1,9 @@
import abc
- from typing import List, Optional, Iterable, Dict, Iterator, Callable, Any, Union
+ from typing import List, Optional, Iterable, Dict, Iterator, Callable, Union

import torch
from transformers import DataCollatorForSeq2Seq, BatchEncoding

from ..lang_module import LangModule
from ..objectives.objective_base import SupervisedObjective, Objective
from ..utils import Head

@@ -23,6 +22,18 @@ def __init__(self, *args,
self.target_lang_id = target_lang_id
super().__init__(*args, **kwargs)

+ if hasattr(self.tokenizer, "lang_code_to_id") and self.source_lang_id is not None:
+     assert self.source_lang_id in self.tokenizer.vocab, \
+         ("Objective %s's 'src_lang' is not in its tokenizer's vocabulary. "
+          "This would cause wrong data encodings." % self.source_lang_id)
+     self.tokenizer.src_lang = self.source_lang_id
+
+ if hasattr(self.tokenizer, "lang_code_to_id") and self.target_lang_id is not None:
+     assert self.target_lang_id in self.tokenizer.vocab, \
+         ("Objective %s's 'tgt_lang' is not in its tokenizer's vocabulary. "
+          "This would cause wrong data encodings." % self.tokenizer.tgt_lang)
+     self.tokenizer.tgt_lang = self.target_lang_id

def _get_seq2seq_collated_iterator(self,
source_texts: Iterable[str],
target_texts: Iterable[str]) -> Iterator[BatchEncoding]:
@@ -37,13 +48,9 @@ def _get_seq2seq_collated_iterator(self,
for source_text, target_text in zip(source_texts, target_texts):
self.tokenizer.src_lang = self.source_lang_id
self.tokenizer.tgt_lang = self.target_lang_id
- sample_features = self.tokenizer(source_text, truncation=True)
+ sample_features = dict(self.tokenizer(source_text, text_target=target_text, truncation=True))

- with self.tokenizer.as_target_tokenizer():
-     sample_targets = self.tokenizer(target_text, truncation=True)
- features_batch.append({"input_ids": sample_features.input_ids,
-                        "attention_mask": sample_features.attention_mask,
-                        "labels": sample_targets.input_ids})
+ features_batch.append(sample_features)
if len(features_batch) == self.batch_size:
yield self.collator(features_batch)
features_batch = []
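
The seq2seq changes do two things: the constructor now validates source_lang_id and target_lang_id against the tokenizer's vocabulary and applies them once, and the per-sample encoding replaces the older as_target_tokenizer() context manager with a single call using text_target. A minimal sketch of that encoding path, assuming an mBART-style multilingual tokenizer; facebook/mbart-large-50 and the language codes are only example values:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")

# Language codes must be real vocabulary entries, as the new __init__ checks enforce.
assert "en_XX" in tokenizer.vocab and "de_DE" in tokenizer.vocab

tokenizer.src_lang = "en_XX"
tokenizer.tgt_lang = "de_DE"

# One call encodes both sides; "labels" replaces the separate
# as_target_tokenizer() pass removed by this commit.
features = dict(tokenizer("A sample source sentence.",
                          text_target="Ein Beispielsatz.",
                          truncation=True))
print(sorted(features))  # ['attention_mask', 'input_ids', 'labels']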
