Skip to content

Commit

Permalink
more generic outputs for training_step, validation_step and test_step
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Sep 19, 2024
1 parent 3e78b27 commit 58197d9
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRSchedulerConfig
from torch.nn import Module, ModuleDict
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
Expand Down Expand Up @@ -467,7 +468,7 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
for domains, latents in latents_domain.items()
}

def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT:
"""
The generic step used in `training_step`, `validation_step` and
`test_step`.
Expand Down Expand Up @@ -496,7 +497,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tenso

def validation_step( # type: ignore
self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
) -> STEP_OUTPUT:
"""Validation step used by lightning"""

batch = {frozenset(data.keys()): data}
Expand All @@ -508,7 +509,7 @@ def validation_step( # type: ignore

def test_step( # type: ignore
self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
) -> STEP_OUTPUT:
"""Test step used by lightning"""

batch = {frozenset(data.keys()): data}
Expand All @@ -520,7 +521,7 @@ def test_step( # type: ignore

def training_step( # type: ignore
self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
) -> torch.Tensor:
) -> STEP_OUTPUT:
"""Training step used by lightning"""

return self.generic_step(batch, mode="train")
Expand All @@ -545,7 +546,7 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
scheduler.
"""

optimizer = torch.optim.AdamW(
optimizer = AdamW(
self.parameters(),
lr=self.optim_lr,
weight_decay=self.optim_weight_decay,
Expand Down

0 comments on commit 58197d9

Please sign in to comment.