diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 01509dd..02a37b6 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -9,6 +9,7 @@ from torch.nn import Module, ModuleDict from torch.optim.adamw import AdamW from torch.optim.lr_scheduler import LRScheduler, OneCycleLR +from torch.optim.optimizer import Optimizer from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType from shimmer.modules.domain import DomainModule, LossOutput @@ -230,7 +231,7 @@ def __init__( optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, - scheduler: LRScheduler + scheduler: Callable[[Optimizer], LRScheduler] | None | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: @@ -245,7 +246,8 @@ def __init__( optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define scheduler parameters. - scheduler: scheduler to use. If None is explicitely given, no scheduler + scheduler (`Callable[[Optimizer], LRScheduler]`): Callback that returns the + scheduler to use. If None is explicitly given, no scheduler will be used. 
By default, uses OneCycleScheduler """ super().__init__() @@ -604,7 +606,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: if isinstance(self.scheduler, OneCycleSchedulerSentinel): lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args) else: - lr_scheduler = self.scheduler + lr_scheduler = self.scheduler(optimizer) return { "optimizer": optimizer, @@ -661,7 +663,7 @@ def __init__( scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: ContrastiveLossType | None = None, - scheduler: LRScheduler + scheduler: Callable[[Optimizer], LRScheduler] | None | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh, @@ -739,7 +741,7 @@ def __init__( scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: ContrastiveLossType | None = None, - scheduler: LRScheduler + scheduler: Callable[[Optimizer], LRScheduler] | None | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,