From 9c42e0f5c89b8d63559aa2c13a515976568999d3 Mon Sep 17 00:00:00 2001
From: Xianzhe Ma
Date: Thu, 12 Sep 2024 16:57:01 +0800
Subject: [PATCH] Use `torch.no_grad()` instead of inference mode in
 `pytorch_trainer.py` (#632)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Maximilian Böther
Co-authored-by: Maximilian Böther <2116466+MaxiBoether@users.noreply.github.com>
---
 .../internal/trainer/pytorch_trainer.py      | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
index 40eddb016..47acce439 100644
--- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py
+++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -1,6 +1,7 @@
 # pylint: disable=no-name-in-module
 from __future__ import annotations
 
+import contextlib
 import glob
 import io
 import itertools
@@ -23,6 +24,7 @@
 
 from modyn.common.benchmark.stopwatch import Stopwatch
 from modyn.models.coreset_methods_support import CoresetSupportingModule
+from modyn.models.dlrm.dlrm import DLRM
 from modyn.selector.internal.grpc.generated.selector_pb2 import (
     AvailableLabelsResponse,
     GetAvailableLabelsRequest,
@@ -560,7 +562,14 @@ def downsample_batch(
         self._downsampler.init_downsampler()
         self.start_embedding_recording_if_needed()
 
-        with torch.inference_mode(mode=not self._downsampler.requires_grad):
+        # DLRM does not support inference_mode(), as it will complain during training
+        # that "inference tensors cannot be saved for backward". It could be that some
+        # DLRM parameters are lazily created during the first forward pass and hence they
+        # are created as inference tensors if inference mode is used here. If this becomes
+        # a problem for more models, we might want to make it a field on the model class instead.
+        no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode()
+        context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr
+        with context_manager:
             big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
             embeddings = self.get_embeddings_if_recorded()
             self._downsampler.inform_samples(sample_ids, data, big_batch_output, target, embeddings)
@@ -904,7 +913,9 @@ def _iterate_dataloader_and_compute_scores(
                 sample_ids, target, data = self.preprocess_batch(batch)
                 number_of_samples += len(sample_ids)
 
-                with torch.inference_mode(mode=not self._downsampler.requires_grad):
+                no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode()
+                context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr
+                with context_manager:
                     with torch.autocast(self._device_type, enabled=self._amp):
                         # compute the scores and accumulate them
                         model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
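
Note on why the patch distinguishes the two modes: a tensor created under
torch.inference_mode() is permanently marked as an inference tensor and can
never participate in autograd, even after the context exits, whereas
torch.no_grad() merely skips gradient tracking for operations inside the
block. The following standalone sketch (not part of the patch; all names are
illustrative) reproduces the error quoted in the patch comment:

import torch

# A tensor materialized inside inference_mode() is an "inference tensor":
# it can never be saved for backward, even after the context exits. This
# mimics a parameter that a model creates lazily on its first forward pass.
with torch.inference_mode():
    lazy_param = torch.ones(3)

x = torch.ones(3, requires_grad=True)
try:
    (x * lazy_param).sum().backward()
except RuntimeError as err:
    print(err)  # Inference tensors cannot be saved for backward. ...

# A tensor materialized inside no_grad() is an ordinary tensor with no grad
# history; it can still be used in later autograd computations.
with torch.no_grad():
    lazy_param = torch.ones(3)

(x * lazy_param).sum().backward()
print(x.grad)  # tensor([1., 1., 1.])

The contextlib.nullcontext() branch keeps the control flow uniform: when the
downsampler itself requires gradients, no grad-disabling context is entered
at all.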