Skip to content

Commit

Permalink
more generic outputs for training_step, validation_step and test_step (
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Sep 19, 2024
1 parent 2c9b60d commit 7d90ee9
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import (
OptimizerLRScheduler,
)
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from torch.nn import Module, ModuleDict
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
Expand Down Expand Up @@ -486,7 +484,7 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
for domains, latents in latents_domain.items()
}

def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT:
"""
The generic step used in `training_step`, `validation_step` and
`test_step`.
Expand Down Expand Up @@ -515,7 +513,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tenso

def validation_step( # type: ignore
self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
) -> STEP_OUTPUT:
"""Validation step used by lightning"""

batch = {frozenset(data.keys()): data}
Expand All @@ -527,7 +525,7 @@ def validation_step( # type: ignore

def test_step( # type: ignore
self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
) -> STEP_OUTPUT:
"""Test step used by lightning"""

batch = {frozenset(data.keys()): data}
Expand All @@ -539,7 +537,7 @@ def test_step( # type: ignore

def training_step( # type: ignore
self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
) -> torch.Tensor:
) -> STEP_OUTPUT:
"""Training step used by lightning"""

return self.generic_step(batch, mode="train")
Expand Down

0 comments on commit 7d90ee9

Please sign in to comment.