Removed objective_id propagation
stefanik12 committed Feb 13, 2024
1 parent a0f47b9 commit 1a77e7c
Showing 6 changed files with 14 additions and 53 deletions.
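In short: batches no longer carry an "oid" entry identifying the objective that produced them; the schedule instead records the producing objective's identity out-of-band through the new Objective.get_id() method and looks the objective up by that id when computing the loss. A minimal illustrative sketch of the resulting flow (simplified, hypothetical MiniObjective/MiniSchedule names, not code from the repository):

# Illustrative sketch only (hypothetical classes, not repository code):
# matching batches to objectives without an "oid" field travelling inside each batch.

class MiniObjective:
    def get_id(self) -> int:
        # mirrors the Objective.get_id() introduced by this commit
        return id(self)

    def compute_loss(self, inputs, labels):
        raise NotImplementedError  # objective-specific loss

class MiniSchedule:
    def __init__(self, objectives):
        # objectives keyed by their id, analogous to self.objectives[split][oid]
        self.objectives = {o.get_id(): o for o in objectives}
        self.outputs_queue = []

    def sampler(self, objective, dataset):
        for sample in dataset:
            # the schedule remembers which objective produced the batch,
            # instead of writing sample["oid"] into the batch itself
            self.outputs_queue.append(objective.get_id())
            yield sample

    def compute_loss(self, inputs, labels):
        oid = self.outputs_queue.pop(0)
        return self.objectives[oid].compute_loss(inputs, labels)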
6 changes: 1 addition & 5 deletions adaptor/adapter.py
@@ -69,11 +69,7 @@ def compute_loss(self,
                      return_outputs: bool = False) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, None]]:
         labels = inputs["labels"] if "labels" in inputs else inputs["label"]
 
-        if self.label_smoother is not None:
-            raise NotImplementedError() # implementation of label smoothing is objective-dependent
-            # loss = self.label_smoother(outputs, labels)
-        else:
-            loss = self.schedule.compute_loss(inputs, labels)
+        loss = self.schedule.compute_loss(inputs, labels)
 
         mock_outputs = torch.tensor([-1, -1])
         return (loss, mock_outputs) if return_outputs else loss
25 changes: 0 additions & 25 deletions adaptor/lang_module.py
@@ -1,5 +1,4 @@
 import logging
-import inspect
 from typing import List, Dict, Any, Optional
 
 import torch
@@ -151,30 +150,6 @@ def _partially_merge_models(orig_model: torch.nn.Module,
 
         return unmatched_modules
 
-    def forward(self, return_loss: bool = True, **inputs) -> torch.LongTensor:
-        """
-        Performs forward pass over the head identified by the sample's `oid`.
-        :param inputs: given head input arguments with corresponding values.
-        :return: Raw model outputs (logits).
-        """
-        try:
-            selected_head_model = self.trainable_models[str(inputs["oid"].item())]
-        except KeyError:
-            raise ValueError("Requesting inference with the objective having no registered head."
-                             "If you are using `extra_eval_objectives`, "
-                             "do not forget to fill in their `share_other_objective_head`.")
-        # include only correct inputs for a specific model
-        list_of_model_specific_inputs = inspect.getfullargspec(selected_head_model.forward).args
-        model_specific_inputs = {k: v for k, v in inputs.items() if k in list_of_model_specific_inputs}
-
-        # including labels cause the loss to be computed twice - by objective + by HF models forward()
-        # but labels are also used to infer decoder_input_ids of some models, so we need to pass it
-        selected_head_output = selected_head_model(**model_specific_inputs)
-        # HF models produce special Output objects instead of a raw output
-        logits = selected_head_output.logits if hasattr(selected_head_output, "logits") else selected_head_output
-
-        return logits
-
     def reinitialize(self, seed: int = 42) -> None:
         """
         Resets the trainable weights of all trainable_models.
23 changes: 8 additions & 15 deletions adaptor/objectives/objective_base.py
@@ -150,9 +150,7 @@ def per_objective_log(self, split: str) -> Dict[str, float]:
 
         for evaluator in self.evaluators[split]:
             dataset = self.get_dataset(split, 0, self.compatible_head_model.device,
-                                       firstn=self.max_samples_per_log[split],
-                                       add_oid=False,
-                                       is_training_dataset=False)
+                                       firstn=self.max_samples_per_log[split], is_training_dataset=False)
             # evaluator should already return an aggregated value, so unlike loss, we don't average it
             evaluator_value = evaluator(self.compatible_head_model, self.tokenizer, dataset)
             self.evaluations_history[split][evaluator].append(evaluator_value)
@@ -261,7 +259,6 @@ def get_dataset(self,
                     objective_i: int,
                     device: Union[str, torch.device],
                     firstn: Optional[int] = None,
-                    add_oid: bool = True,
                     is_training_dataset: bool = True,
                     show_progressbar: bool = True) -> TransformerAdaptationDataset:
         """
@@ -270,8 +267,8 @@
         :param objective_i: Rank of this objective in schedule. Used only to properly set up progress bar.
         :param device: Device to transfer this data set to.
         :param firstn: If given, a number of the retrieved items from the dataset.
-        :param add_oid: Whether to append objective id to the match. Required for forward pass over LangModule.
         :param is_training_dataset: Whether this dataset is used for training -> if to update the epochs counter.
         :param show_progressbar: Whether to maintain a dataset iterator progress bar for this objective.
         :return: TransformerAdaptationDataset wrapping a data set of this objective.
         """
@@ -293,12 +290,8 @@
 
         inputs_iter = self._get_inputs_iterator(split)
 
-        def _sample_to_device(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
-            return {k: v.to(device) if k != "oid" else v for k, v in sample.items()}
-
-        def _add_oid(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
-            sample["oid"] = torch.tensor(id(self))
-            return sample
+        def _sample_to_device(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]:
+            return {k: v.to(device) for k, v in sample.items()}
 
         def _remember_input(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
             self.last_input = sample
@@ -310,9 +303,6 @@ def _update_pbar(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> D
 
         device_inputs_iter = map(_sample_to_device, inputs_iter)
 
-        if add_oid:
-            device_inputs_iter = map(_add_oid, device_inputs_iter)
-
         if firstn is not None and firstn < self.dataset_length[split]:
             device_inputs_iter = itertools.islice(device_inputs_iter, firstn)
 
@@ -324,6 +314,9 @@ def _update_pbar(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> D
 
         return TransformerAdaptationDataset(device_inputs_iter, self.dataset_length[split])
 
+    def get_id(self) -> int:
+        return id(self)
+
     def compute_loss_on_last_sample(self) -> torch.FloatTensor:
         """
         This method aims to reproduce an error of calling the `objective.compatible_head_model`
@@ -338,7 +331,7 @@ def compute_loss_on_last_sample(self) -> torch.FloatTensor:
         labels = self.last_input["labels"]
 
         logger.warning("Computing model output")
-        model_inputs = {k: v for k, v in self.last_input.items() if k not in ("oid", "labels")}
+        model_inputs = {k: v for k, v in self.last_input.items() if k not in ("labels")}
         outputs = self.compatible_head_model(**model_inputs)
         logger.warning("Model outputs computation on the recent sample successful. Outputs: %s", outputs)
 
7 changes: 3 additions & 4 deletions adaptor/schedules.py
@@ -175,8 +175,7 @@ def compute_loss(self,
         split, oid = self.objectives_outputs_queue.pop(0)
 
         # the objective loss arrives aggregated into a single item
-        inputs.pop("oid")
-        loss = self.objectives[split][oid.item()].compute_loss(inputs, labels, split)
+        loss = self.objectives[split][oid].compute_loss(inputs, labels, split)
 
         return loss
 
@@ -191,7 +190,7 @@ def _one_round_eval_objective_sampler(self, objective: Objective, obj_i: int) ->
         """
         dataset = objective.get_dataset("eval", obj_i, self.args.device)
         for sample in dataset:
-            self.objectives_outputs_queue.append(("eval", sample["oid"]))
+            self.objectives_outputs_queue.append(("eval", objective.get_id()))
             yield sample
 
     def _infinite_train_objective_sampler(self, objective: Objective, obj_i: int) -> Iterator[Dict[str, Any]]:
@@ -209,7 +208,7 @@ def _infinite_train_objective_sampler(self, objective: Objective, obj_i: int) ->
 
         dataset = objective.get_dataset("train", obj_i, self.args.device)
         for sample in dataset:
-            self.objectives_outputs_queue.append(("train", sample["oid"]))
+            self.objectives_outputs_queue.append(("train", objective.get_id()))
             yield sample
 
     def _sample_objective_dataset(self, objective: Objective, obj_i: int, split: str) -> Iterator[Dict[str, Any]]:
3 changes: 1 addition & 2 deletions tests/evaluators_test.py
@@ -13,8 +13,7 @@ def assert_evaluator_logs(objective: Objective, split: str) -> None:
     dataset_sample = next(iter(objective.get_dataset(split, objective_i=0, device="cpu")))
 
     # request objective for its loss
-    loss = objective.compute_loss({k: v for k, v in dataset_sample.items() if k not in ("oid",)},
-                                  dataset_sample["labels"], split)
+    loss = objective.compute_loss(dataset_sample, dataset_sample["labels"], split)
     assert loss.item()
 
     log = objective.per_objective_log(split)
3 changes: 1 addition & 2 deletions tests/objectives_test.py
@@ -19,8 +19,7 @@ def assert_module_objective_ok(lang_module: LangModule, objective: Objective, sp
     # outputs = lang_module(**dataset_sample)
 
     # loss computation test, possible label smoothing is performed by Adapter
-    loss = objective.compute_loss({k: v for k, v in dataset_sample.items() if k not in ("oid",)},
-                                  dataset_sample["labels"], split)
+    loss = objective.compute_loss(dataset_sample, dataset_sample["labels"], split)
 
     # check that retrieved loss has a backward_fn
     loss.backward()
