Skip to content

Commit

Permalink
scheduler should take optimizer as param
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Oct 8, 2024
1 parent 760f7ee commit c2bf716
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from enum import Enum, auto
from pathlib import Path
from typing import Any, Generic, TypedDict, TypeVar, cast
Expand All @@ -9,6 +9,7 @@
from torch.nn import Module, ModuleDict
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from torch.optim.optimizer import Optimizer

from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
from shimmer.modules.domain import DomainModule, LossOutput
Expand Down Expand Up @@ -230,7 +231,7 @@ def __init__(
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
scheduler: LRScheduler
scheduler: Callable[[Optimizer], LRScheduler]
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
Expand All @@ -245,7 +246,8 @@ def __init__(
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
scheduler parameters.
scheduler: scheduler to use. If None is explicitly given, no scheduler
scheduler (`Callable[[Optimizer], LRScheduler]`): Callback that returns the
scheduler to use. If None is explicitly given, no scheduler
will be used. By default, uses OneCycleScheduler
"""
super().__init__()
Expand Down Expand Up @@ -603,7 +605,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
if isinstance(self.scheduler, OneCycleSchedulerSentinel):
lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
else:
lr_scheduler = self.scheduler
lr_scheduler = self.scheduler(optimizer)

return {
"optimizer": optimizer,
Expand Down Expand Up @@ -660,7 +662,7 @@ def __init__(
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
contrastive_loss: ContrastiveLossType | None = None,
scheduler: LRScheduler
scheduler: Callable[[Optimizer], LRScheduler]
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
Expand Down Expand Up @@ -733,7 +735,7 @@ def __init__(
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
contrastive_loss: ContrastiveLossType | None = None,
scheduler: LRScheduler
scheduler: Callable[[Optimizer], LRScheduler]
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
Expand Down

0 comments on commit c2bf716

Please sign in to comment.