more general output for configure_optimizers
bdvllrs committed Sep 18, 2024
1 parent 64cdc30 commit 149d9b8
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions shimmer/modules/global_workspace.py
@@ -5,8 +5,11 @@

 import torch
 from lightning.pytorch import LightningModule
-from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
+from lightning.pytorch.utilities.types import (
+    OptimizerLRScheduler,
+)
 from torch.nn import Module, ModuleDict
+from torch.optim.adamw import AdamW
 from torch.optim.lr_scheduler import LRScheduler, OneCycleLR

 from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
@@ -553,15 +556,15 @@ def predict_step( # type: ignore
         domain_latents = self.encode_domains(batch)
         return self.forward(domain_latents)

-    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
+    def configure_optimizers(self) -> OptimizerLRScheduler:
         """
         Configure models optimizers.

         Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
         scheduler.
         """

-        optimizer = torch.optim.AdamW(
+        optimizer = AdamW(
             self.parameters(),
             lr=self.optim_lr,
             weight_decay=self.optim_weight_decay,
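Why the change matters: Lightning's `OptimizerLRScheduler` alias covers every return shape that `configure_optimizers` may produce (a bare optimizer, a sequence of optimizers, or an optimizer/scheduler config dict), whereas the previous `OptimizerLRSchedulerConfig` annotation described only the dict form. Subclasses that override `configure_optimizers` can therefore return, for example, just an optimizer without violating the signature. Below is a minimal sketch of that, outside the shimmer codebase; the `TinyModule` class and its hyperparameters are hypothetical and not part of this commit.

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torch.nn import Linear
from torch.optim.adamw import AdamW


class TinyModule(LightningModule):
    """Hypothetical example module, not part of shimmer."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Returning the optimizer alone is valid under OptimizerLRScheduler;
        # the optimizer-plus-OneCycleLR dict built in global_workspace.py
        # satisfies the same annotation.
        return AdamW(self.parameters(), lr=1e-3, weight_decay=0.01)

The commit itself keeps the existing behavior (AdamW wrapped with a OneCycleLR scheduler); only the annotation is loosened and the `AdamW` import is made explicit.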
