diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 95c7f250..c0df2fa3 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -5,9 +5,7 @@ import torch from lightning.pytorch import LightningModule -from lightning.pytorch.utilities.types import ( - OptimizerLRScheduler, -) +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler from torch.nn import Module, ModuleDict from torch.optim.adamw import AdamW from torch.optim.lr_scheduler import LRScheduler, OneCycleLR @@ -486,7 +484,7 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup for domains, latents in latents_domain.items() } - def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor: + def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT: """ The generic step used in `training_step`, `validation_step` and `test_step`. @@ -515,7 +513,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tenso def validation_step( # type: ignore self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0 - ) -> torch.Tensor: + ) -> STEP_OUTPUT: """Validation step used by lightning""" batch = {frozenset(data.keys()): data} @@ -527,7 +525,7 @@ def validation_step( # type: ignore def test_step( # type: ignore self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0 - ) -> torch.Tensor: + ) -> STEP_OUTPUT: """Test step used by lightning""" batch = {frozenset(data.keys()): data} @@ -539,7 +537,7 @@ def test_step( # type: ignore def training_step( # type: ignore self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int - ) -> torch.Tensor: + ) -> STEP_OUTPUT: """Training step used by lightning""" return self.generic_step(batch, mode="train")