diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index a89ee0dc..4aa6d2ec 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -1,4 +1,4 @@ -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from enum import Enum, auto from pathlib import Path from typing import Any, Generic, TypedDict, TypeVar, cast @@ -9,6 +9,7 @@ from torch.nn import Module, ModuleDict from torch.optim.adamw import AdamW from torch.optim.lr_scheduler import LRScheduler, OneCycleLR +from torch.optim.optimizer import Optimizer from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType from shimmer.modules.domain import DomainModule, LossOutput @@ -230,7 +231,7 @@ def __init__( optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, - scheduler: LRScheduler + scheduler: Callable[[Optimizer], LRScheduler] | None | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: @@ -245,7 +246,8 @@ def __init__( optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define scheduler parameters. - scheduler: scheduler to use. If None is explicitely given, no scheduler + scheduler (`Callable[[Optimizer], LRScheduler]`): Callback that returns the + scheduler to use. If None is explicitely given, no scheduler will be used. By default, uses OneCycleScheduler """ super().__init__() @@ -603,7 +605,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: if isinstance(self.scheduler, OneCycleSchedulerSentinel): lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args) else: - lr_scheduler = self.scheduler + lr_scheduler = self.scheduler(optimizer) return { "optimizer": optimizer, @@ -660,7 +662,7 @@ def __init__( scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: ContrastiveLossType | None = None, - scheduler: LRScheduler + scheduler: Callable[[Optimizer], LRScheduler] | None | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: @@ -733,7 +735,7 @@ def __init__( scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: ContrastiveLossType | None = None, - scheduler: LRScheduler + scheduler: Callable[[Optimizer], LRScheduler] | None | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: