Commit

CI checks
stefanik12 committed Feb 12, 2024
1 parent fb2e025 commit d0cbe17
Showing 6 changed files with 99 additions and 100 deletions.
4 changes: 2 additions & 2 deletions adaptor/objectives/denoising.py
@@ -2,7 +2,7 @@
 import collections
 import itertools
 import random
-from typing import List, Tuple, Optional, Iterator
+from typing import List, Tuple, Optional, Iterable
 
 from transformers import BatchEncoding
 
@@ -147,7 +147,7 @@ def _apply_noise(self, text: str) -> str:
         out_text = noising_fn(out_text, self.noising_per_sentence)
         return out_text
 
-    def _get_inputs_iterator(self, split: str) -> Iterator[BatchEncoding]:
+    def _get_inputs_iterator(self, split: str) -> Iterable[BatchEncoding]:
         """
         Generates labels by applying selected noising strategies on inputs.
         :param split: Data split. `train` or `eval`.
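
Editor's note on the annotation change above: _get_inputs_iterator is implemented as a generator, and every generator object is an Iterator, which is in turn an Iterable, so Iterable is the weaker and more caller-friendly contract. A minimal sketch of that relationship, with a hypothetical batches helper standing in for _get_inputs_iterator:

    from collections.abc import Iterable, Iterator
    from typing import List

    def batches(texts: List[str], batch_size: int) -> Iterable[List[str]]:
        # hypothetical stand-in for _get_inputs_iterator: a generator
        # function lazily returns a generator object
        for i in range(0, len(texts), batch_size):
            yield texts[i:i + batch_size]

    gen = batches(["a", "b", "c"], 2)
    # every generator is an Iterator, and every Iterator is an Iterable,
    # so Iterable[...] is a correct and least-restrictive annotation
    assert isinstance(gen, Iterator) and isinstance(gen, Iterable)
    print(list(gen))  # [['a', 'b'], ['c']]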
7 changes: 4 additions & 3 deletions adaptor/objectives/objective_base.py
@@ -339,10 +339,11 @@ def compute_loss_on_last_sample(self) -> torch.FloatTensor:
 
         logger.warning("Computing model output")
         model_inputs = {k: v for k, v in self.last_input.items() if k not in ("oid", "labels")}
-        logits = self.compatible_head_model(**model_inputs).logits
+        outputs = self.compatible_head_model(**model_inputs)
+        logger.warning("Model outputs computation on the recent sample successful. Outputs: %s", outputs)
 
         logger.warning("Computing loss")
-        loss = self._compute_loss(logits, labels, self.last_input)
+        loss = self._compute_loss(self.last_input, labels)
 
         logger.warning("Loss computation on the recent sample successful. Loss value: %s", loss.item())
         return loss
@@ -519,7 +520,7 @@ def register_compatible_head_model(self, lang_module: LangModule,
         return super().register_compatible_head_model(lang_module, other_objective,
                                                       objective_args_for_head_config, preloaded_module)
 
-    def _get_inputs_iterator(self, split: str) -> Iterator[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
+    def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
        """
         Batches and encodes input texts and corresponding labels.
         :param split: Selected data split. `train` or `eval`.
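
Editor's note on the first hunk above: compute_loss_on_last_sample no longer pre-extracts logits; it logs the full model output object for debugging and delegates to _compute_loss(inputs, labels). A rough sketch of the resulting flow as a free-standing function, with invented argument names (the real method lives on the objective and _compute_loss is defined elsewhere in adaptor):

    import logging
    import torch

    logger = logging.getLogger(__name__)

    def loss_on_last_sample(model, last_input, compute_loss_fn) -> torch.FloatTensor:
        # strip bookkeeping keys that the model's forward() does not accept
        labels = last_input["labels"]
        model_inputs = {k: v for k, v in last_input.items() if k not in ("oid", "labels")}
        # run the model and log the whole output object, not just .logits,
        # so a failing sample can be inspected in full
        outputs = model(**model_inputs)
        logger.warning("Model outputs on the last sample: %s", outputs)
        # the loss routine re-derives what it needs from raw inputs and labels
        loss = compute_loss_fn(last_input, labels)
        logger.warning("Loss on the last sample: %s", loss.item())
        return loss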
2 changes: 1 addition & 1 deletion adaptor/objectives/seq2seq.py
@@ -52,7 +52,7 @@ def _get_seq2seq_collated_iterator(self,
         # yield last nonempty residual batch
         yield self.collator(features_batch)
 
-    def _get_inputs_iterator(self, split: str) -> Iterator[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
+    def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
         """
         Creates a default iterator over encodings with aligned input and output texts.
         :param split: Data split. `train` or `eval`.
93 changes: 0 additions & 93 deletions tests/distillation_test.py

This file was deleted.

3 changes: 2 additions & 1 deletion tests/evaluators_test.py
@@ -13,7 +13,8 @@ def assert_evaluator_logs(objective: Objective, split: str) -> None:
     dataset_sample = next(iter(objective.get_dataset(split, objective_i=0, device="cpu")))
 
     # request objective for its loss
-    loss = objective.compute_loss(dataset_sample, dataset_sample["labels"], split)
+    loss = objective.compute_loss({k: v for k, v in dataset_sample.items() if k not in ("oid",)},
+                                  dataset_sample["labels"], split)
     assert loss.item()
 
     log = objective.per_objective_log(split)
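
Editor's note: the test change above mirrors the filtering introduced in objective_base.py; internal bookkeeping keys (here "oid", presumably an objective identifier attached by the data pipeline) are stripped from the batch before it reaches compute_loss. The idiom in isolation, with an invented batch:

    # invented batch for illustration; only the filtering idiom matters
    batch = {"input_ids": [[101, 102]], "labels": [[101, 102]], "oid": 0}
    model_kwargs = {k: v for k, v in batch.items() if k not in ("oid",)}
    assert "oid" not in model_kwargs and "labels" in model_kwargs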
90 changes: 90 additions & 0 deletions tests/objectives_test.py
@@ -3,6 +3,7 @@
 from adaptor.objectives.MLM import MaskedLanguageModeling
 from adaptor.objectives.backtranslation import BackTranslation, BackTranslator
 from adaptor.objectives.classification import TokenClassification
+from adaptor.objectives.distillation import Distillation
 from adaptor.objectives.denoising import DenoisingObjective
 from adaptor.objectives.objective_base import Objective
 from adaptor.objectives.question_answering import ExtractiveQA
@@ -123,6 +124,83 @@ def test_supervised_seq2seq_objective_mbart():
     assert_module_objective_ok(lang_module, objective)
 
 
+def test_distillation_seq():
+    from adaptor.objectives.seq2seq import Sequence2Sequence
+    from transformers import AutoModelForSeq2SeqLM
+
+    class DistilledSeq2Seq(Distillation, Sequence2Sequence):
+        # this is a full implementation of distillation within another objective
+        pass
+
+    lang_module = LangModule(test_base_models["translation_mono"])
+    distilled_model = AutoModelForSeq2SeqLM.from_pretrained(test_base_models["translation_mono"])
+
+    objective = DistilledSeq2Seq(lang_module,
+                                 teacher_model=distilled_model,
+                                 texts_or_path=paths["texts"]["translation"],
+                                 labels_or_path=paths["labels"]["translation"],
+                                 batch_size=4)
+
+    assert_module_objective_ok(lang_module, objective)
+
+
+def test_distillation_mlm():
+    from adaptor.objectives.MLM import MaskedLanguageModeling
+    from transformers import AutoModelForMaskedLM
+
+    class DistilledMLM(Distillation, MaskedLanguageModeling):
+        pass
+
+    lang_module = LangModule(test_base_models["MLM_student"])
+    distilled_model = AutoModelForMaskedLM.from_pretrained(test_base_models["MLM"])
+
+    objective = DistilledMLM(lang_module,
+                             teacher_model=distilled_model,
+                             texts_or_path=paths["texts"]["unsup"],
+                             batch_size=4)
+
+    assert_module_objective_ok(lang_module, objective)
+
+
+def test_distillation_mlm_incl_hidden_states():
+    from adaptor.objectives.MLM import MaskedLanguageModeling
+    from transformers import AutoModelForMaskedLM
+
+    class DistilledMLM(Distillation, MaskedLanguageModeling):
+        pass
+
+    lang_module = LangModule(test_base_models["MLM_student"])
+    distilled_model = AutoModelForMaskedLM.from_pretrained(test_base_models["MLM"])
+
+    objective = DistilledMLM(lang_module,
+                             teacher_model=distilled_model,
+                             add_hidden_states_loss=True,
+                             texts_or_path=paths["texts"]["unsup"],
+                             batch_size=4)
+
+    assert_module_objective_ok(lang_module, objective)
+
+
+def test_distillation_mlm_restrict_to_attention():
+    from adaptor.objectives.MLM import MaskedLanguageModeling
+    from transformers import AutoModelForMaskedLM
+
+    class DistilledMLM(Distillation, MaskedLanguageModeling):
+        pass
+
+    lang_module = LangModule(test_base_models["MLM_student"])
+    distilled_model = AutoModelForMaskedLM.from_pretrained(test_base_models["MLM"])
+
+    objective = DistilledMLM(lang_module,
+                             teacher_model=distilled_model,
+                             add_hidden_states_loss=True,
+                             restrict_loss_to_mask=True,
+                             texts_or_path=paths["texts"]["unsup"],
+                             batch_size=4)
+
+    assert_module_objective_ok(lang_module, objective)
+
+
 def test_supervised_QA_objective():
     lang_module = LangModule(test_base_models["extractive_QA"])
 
@@ -133,3 +211,15 @@ def test_supervised_QA_objective():
                          batch_size=4)
 
     assert_module_objective_ok(lang_module, objective)
+
+
+# def test_search_objective():
+#     lang_module = LangModule(test_base_models["extractive_QA"])
+#
+#     objective = Encoding(lang_module,
+#                          texts_or_path=paths["texts"]["QA"],
+#                          text_pair_or_path=paths["text_pair"]["QA"],
+#                          labels_or_path=paths["labels"]["QA"],
+#                          batch_size=4)
+#
+#     assert_module_objective_ok(lang_module, objective)
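
Editor's note: the new tests compose Distillation with a concrete objective through cooperative multiple inheritance (class DistilledSeq2Seq(Distillation, Sequence2Sequence)), so the distillation loss wraps the host objective's own loss. Below is a sketch of the classic logit-distillation term such a mixin typically adds; the temperature, the weighting, and the real internals of adaptor's Distillation class are assumptions, not the commit's code:

    import torch
    import torch.nn.functional as F

    def distillation_term(student_logits: torch.Tensor,
                          teacher_logits: torch.Tensor,
                          temperature: float = 2.0) -> torch.Tensor:
        # soften both distributions with a temperature, then pull the
        # student toward the teacher with KL divergence; the T**2 factor
        # keeps gradient magnitudes comparable across temperatures
        # (Hinton et al., 2015)
        t = temperature
        student_log_probs = F.log_softmax(student_logits / t, dim=-1)
        teacher_probs = F.softmax(teacher_logits / t, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs,
                        reduction="batchmean") * (t ** 2)

    # typical combination inside a distilled objective (hypothetical weights):
    # loss = 0.5 * student_ce + 0.5 * distillation_term(s_out.logits, t_out.logits)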