Commit 9c42e0f

Use torch.no_grad() instead of inference mode in `pytorch_trainer.py` (#632)

Co-authored-by: Maximilian Böther <[email protected]>
3 people authored Sep 12, 2024
1 parent 5e0d789 commit 9c42e0f
Showing 1 changed file with 13 additions and 2 deletions.
modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -1,6 +1,7 @@
 # pylint: disable=no-name-in-module
 from __future__ import annotations

+import contextlib
 import glob
 import io
 import itertools
@@ -23,6 +24,7 @@

 from modyn.common.benchmark.stopwatch import Stopwatch
 from modyn.models.coreset_methods_support import CoresetSupportingModule
+from modyn.models.dlrm.dlrm import DLRM
 from modyn.selector.internal.grpc.generated.selector_pb2 import (
     AvailableLabelsResponse,
     GetAvailableLabelsRequest,
@@ -560,7 +562,14 @@ def downsample_batch(
         self._downsampler.init_downsampler()
         self.start_embedding_recording_if_needed()

-        with torch.inference_mode(mode=not self._downsampler.requires_grad):
+        # DLRM does not support inference_mode(), as it will complain during training
+        # that "inference tensors cannot be saved for backward". It could be that some
+        # DLRM parameters are lazily created during the first forward pass and hence
+        # they are created as inference tensors if inference mode is used here. If this
+        # becomes a problem for more models, we might want to make it a field on the model class instead.
+        no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode()
+        context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr
+        with context_manager:
             big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
             embeddings = self.get_embeddings_if_recorded()
             self._downsampler.inform_samples(sample_ids, data, big_batch_output, target, embeddings)
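
To make the new comment concrete: below is a minimal, self-contained sketch (not part of the commit) of the failure mode it describes. The lazily created `lazy_weight` is a hypothetical stand-in for the DLRM parameters the comment speculates about; the real trigger inside DLRM may differ.

import torch

lazy_weight = None  # hypothetical stand-in for a lazily created DLRM parameter

def forward(x: torch.Tensor) -> torch.Tensor:
    global lazy_weight
    if lazy_weight is None:
        # First forward pass: under inference_mode(), this new tensor is
        # permanently marked as an inference tensor.
        lazy_weight = torch.randn(4, 4)
    return x @ lazy_weight

with torch.inference_mode():
    forward(torch.randn(2, 4))  # creates lazy_weight as an inference tensor

try:
    # Training-time call: matmul must save lazy_weight for the backward pass,
    # which is not allowed for inference tensors, so this raises.
    forward(torch.randn(2, 4, requires_grad=True)).sum().backward()
except RuntimeError as err:
    print(err)  # "Inference tensors cannot be saved for backward. ..."

Tensors created under `torch.no_grad()` are ordinary tensors, so swapping `no_grad()` in for `inference_mode()` above avoids the error, which is exactly the trade-off this commit makes for DLRM.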
@@ -904,7 +913,9 @@ def _iterate_dataloader_and_compute_scores(
             sample_ids, target, data = self.preprocess_batch(batch)
             number_of_samples += len(sample_ids)

-            with torch.inference_mode(mode=not self._downsampler.requires_grad):
+            no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode()
+            context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr
+            with context_manager:
                 with torch.autocast(self._device_type, enabled=self._amp):
                     # compute the scores and accumulate them
                     model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
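
Both hunks build the same context manager inline. As a standalone illustration of that selection logic (the `pick_context` helper is hypothetical, not part of the commit):

import contextlib

import torch

def pick_context(requires_grad: bool, model_is_dlrm: bool):
    """Return the gradient context both hunks construct inline."""
    if requires_grad:
        # The downsampler needs gradients: do not suppress autograd at all.
        return contextlib.nullcontext()
    # no_grad() is the safe fallback for DLRM; inference_mode() is the faster
    # choice elsewhere since it also skips version counters and view tracking.
    return torch.no_grad() if model_is_dlrm else torch.inference_mode()

model = torch.nn.Linear(8, 2)  # toy stand-in for self._model.model
with pick_context(requires_grad=False, model_is_dlrm=True):
    scores = model(torch.randn(4, 8))
print(scores.requires_grad)  # False: no_grad() disabled autograd tracking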