
Commit

Objective.get_dataset: automatic resolution of device and num of eval samples
stefanik12 committed Mar 26, 2024
1 parent db96268 commit eb1f952
Showing 1 changed file with 31 additions and 25 deletions.
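In short: the `firstn` parameter of `Objective.get_dataset` is removed, and `device` may now be omitted; it is resolved from the objective's model, falling back to "cpu". A standalone sketch of the new resolution rule follows (the names `model` and `sample_to_device` are illustrative stand-ins, not the library's code; the real logic lives in the `_sample_to_device` helper shown in the diff below):

    # Standalone sketch of the device-resolution rule this commit adds.
    from functools import partial
    from typing import Dict, Optional, Union

    import torch

    model = None  # stand-in for self.compatible_head_model; may be unset

    def sample_to_device(chosen_device: Optional[Union[str, torch.device]],
                         sample: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        if chosen_device is None:
            # fall back to the model's device when one is assigned, else to CPU
            chosen_device = model.device if model is not None else "cpu"
        return {k: v.to(chosen_device) for k, v in sample.items()}

    batch = {"input_ids": torch.tensor([[1, 2, 3]])}
    moved = map(partial(sample_to_device, None), [batch])  # device=None is resolved lazily
    print(next(moved)["input_ids"].device)  # -> cpu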
56 changes: 31 additions & 25 deletions adaptor/objectives/objective_base.py
@@ -1,6 +1,7 @@
 import abc
 import itertools
 import logging
+from functools import partial
 from typing import List, Union, Optional, Iterable, Tuple, Dict, Sequence, Any, Iterator
 
 import torch
@@ -178,7 +179,6 @@ def per_objective_log(self, split: str) -> Dict[str, float]:
 
         for evaluator in self.evaluators[split]:
             dataset = self.get_dataset(split, 0, self.compatible_head_model.device,
-                                       firstn=self.max_samples_per_log[split],
                                        add_oid=False,
                                        is_training_dataset=False)
             # evaluator should already return an aggregated value, so unlike loss, we don't average it
@@ -289,8 +289,7 @@ def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict
     def get_dataset(self,
                     split: str,
                     objective_i: int,
-                    device: Union[str, torch.device],
-                    firstn: Optional[int] = None,
+                    device: Optional[Union[str, torch.device]] = None,
                     add_oid: bool = True,
                     is_training_dataset: bool = True,
                     show_progressbar: bool = True) -> TransformerAdaptationDataset:
@@ -299,9 +298,9 @@ def get_dataset(self,
         :param split: A split of the retrieved dataset. `train` or `eval`.
         :param objective_i: Rank of this objective in schedule. Used only to properly set up progress bar.
         :param device: Device to transfer this data set to.
-        :param firstn: If given, a number of the retrieved items from the dataset.
         :param add_oid: Whether to append objective id to the match. Required for forward pass over LangModule.
         :param is_training_dataset: Whether this dataset is used for training -> if to update the epochs counter.
+        :param show_progressbar: Whether to maintain a dataset iterator progress bar for this objective.
         :return: TransformerAdaptationDataset wrapping a data set of this objective.
         """
@@ -310,25 +309,16 @@ def get_dataset(self,
         # - get_dataset is also called from self.per_objective_log, or specific objectives
         self.epoch += 1 if split == "train" else 0
 
-        if show_progressbar:
-            self.progressbar[split] = trange(self.dataset_length[split] // self.batch_size,
-                                             desc=str(self),
-                                             unit="batches",
-                                             position=objective_i,
-                                             leave=True)
-            self.progressbar[split].set_postfix(refresh=False, split=split, epoch=self.epoch, loss=-1)
-        else:
-            # we do not update loss, if no progress bar is pertained
-            self.progressbar[split] = None
-
         inputs_iter = self._get_inputs_iterator(split)
 
-        def _sample_to_device(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
-            return {k: v.to(device) if k != "oid" else v for k, v in sample.items()}
-
-        def _add_oid(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
-            sample["oid"] = torch.tensor(id(self))
-            return sample
+        def _sample_to_device(chosen_device: Optional[Union[str, torch.device]],
+                              sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]:
+            if chosen_device is None:
+                # default device is a device of the model assigned to this objective, if it is set
+                # if it is not, we resort to "cpu"
+                # in classic training, the model is always assigned when the dataset is requested
+                chosen_device = self.compatible_head_model.device if self.compatible_head_model is not None else "cpu"
+            return {k: v.to(chosen_device) for k, v in sample.items()}
 
         def _remember_input(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
             self.last_input = sample
@@ -338,19 +328,35 @@ def _update_pbar(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> D
             self.progressbar[split].update(1)
             return sample
 
-        device_inputs_iter = map(_sample_to_device, inputs_iter)
+        def _add_oid(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
+            sample["oid"] = torch.tensor(id(self))
+            return sample
+
+        device_inputs_iter = map(partial(_sample_to_device, device), inputs_iter)
+
+        if split == "eval" and self.max_samples_per_log["eval"] is not None:
+            device_inputs_iter = itertools.islice(device_inputs_iter, self.max_samples_per_log["eval"])
+            self.dataset_length["eval"] = self.max_samples_per_log["eval"] * self.batch_size
 
         if add_oid:
             device_inputs_iter = map(_add_oid, device_inputs_iter)
 
-        if firstn is not None and firstn < self.dataset_length[split]:
-            device_inputs_iter = itertools.islice(device_inputs_iter, firstn)
-
         if self.remember_last_input:
             device_inputs_iter = map(_remember_input, device_inputs_iter)
 
         if show_progressbar:
+            # set up a new progressbar object
+            self.progressbar[split] = trange(self.dataset_length[split] // self.batch_size,
+                                             desc=str(self),
+                                             unit="batches",
+                                             position=objective_i,
+                                             leave=True)
+            self.progressbar[split].set_postfix(refresh=False, split=split, epoch=self.epoch, loss=-1)
+            # assign a hook to the iterator, to update the progressbar on every yielded sample
             device_inputs_iter = map(_update_pbar, device_inputs_iter)
+        else:
+            # we do not update loss, if no progress bar is pertained
+            self.progressbar[split] = None
 
         return TransformerAdaptationDataset(device_inputs_iter, self.dataset_length[split])
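The removed `firstn` parameter is superseded by the capping now applied inside `get_dataset` itself: for the eval split, the input iterator is truncated to `max_samples_per_log["eval"]` items with `itertools.islice`. A standalone sketch of that pattern (the names `batches` and `max_samples_per_log` are illustrative, not the library's code):

    # Standalone sketch of the eval capping above: a possibly endless iterator
    # of batches is cut off after a fixed number of samples, keeping the
    # evaluation cost bounded.
    import itertools

    batches = ({"batch": i} for i in itertools.count())  # endless eval iterator
    max_samples_per_log = {"eval": 3}

    capped = itertools.islice(batches, max_samples_per_log["eval"])
    print(list(capped))  # -> [{'batch': 0}, {'batch': 1}, {'batch': 2}]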


0 comments on commit eb1f952
