diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
index 40eddb016..47acce439 100644
--- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py
+++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -1,6 +1,7 @@
 # pylint: disable=no-name-in-module
 from __future__ import annotations
 
+import contextlib
 import glob
 import io
 import itertools
@@ -23,6 +24,7 @@
 from modyn.common.benchmark.stopwatch import Stopwatch
 from modyn.models.coreset_methods_support import CoresetSupportingModule
+from modyn.models.dlrm.dlrm import DLRM
 from modyn.selector.internal.grpc.generated.selector_pb2 import (
     AvailableLabelsResponse,
     GetAvailableLabelsRequest,
@@ -560,7 +562,14 @@ def downsample_batch(
         self._downsampler.init_downsampler()
         self.start_embedding_recording_if_needed()
 
-        with torch.inference_mode(mode=not self._downsampler.requires_grad):
+        # DLRM does not support inference_mode(): during training it fails with
+        # "Inference tensors cannot be saved for backward".
+        # It could be that some DLRM parameters are lazily created during the first
+        # forward pass and hence become inference tensors if inference mode is used here.
+        # If this becomes a problem for more models, we might want to make it a field on the model class instead.
+        no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode()
+        context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr
+        with context_manager:
             big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
             embeddings = self.get_embeddings_if_recorded()
         self._downsampler.inform_samples(sample_ids, data, big_batch_output, target, embeddings)
@@ -904,7 +913,9 @@ def _iterate_dataloader_and_compute_scores(
             sample_ids, target, data = self.preprocess_batch(batch)
             number_of_samples += len(sample_ids)
 
-            with torch.inference_mode(mode=not self._downsampler.requires_grad):
+            no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode()
+            context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr
+            with context_manager:
                 with torch.autocast(self._device_type, enabled=self._amp):
                     # compute the scores and accumulate them
                     model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
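
Note on why no_grad() is swapped in for DLRM: a minimal repro sketch of the failure mode the comment above suspects. LazyScale is a hypothetical toy module standing in for the lazy tensor creation, not DLRM's or modyn's actual code.

import torch

class LazyScale(torch.nn.Module):
    # Toy module that materializes a tensor on its first forward pass,
    # mimicking the lazy creation the diff comment suspects in DLRM.
    def __init__(self) -> None:
        super().__init__()
        self.scale = None

    def forward(self, x):
        if self.scale is None:
            # Created under whatever grad mode the caller is in. Under
            # torch.inference_mode() this becomes an *inference* tensor.
            self.scale = torch.ones(x.shape[-1])
        return x * self.scale

model = LazyScale()
with torch.inference_mode():
    model(torch.randn(4, 8))  # first forward: self.scale is an inference tensor

# Later training then fails at the multiply, because autograd must save
# self.scale for backward:
#   RuntimeError: Inference tensors cannot be saved for backward.
# model(torch.randn(4, 8, requires_grad=True))

model = LazyScale()
with torch.no_grad():  # the workaround in the diff: no_grad() instead
    model(torch.randn(4, 8))  # self.scale is a normal tensor, just created without grad history

out = model(torch.randn(4, 8, requires_grad=True))
out.sum().backward()  # works

Under no_grad(), tensors created inside the context are ordinary tensors and remain usable in autograd afterwards, which is why it is a safe (if slightly slower) substitute for inference_mode() here.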