Commit

fix: global workspace scheduler arg as a callback (#169)

The previous implementation (#132) could not work with a new scheduler
because the scheduler could not take the optimizer as a parameter.
Here, the scheduler should instead be given as a callback:

```python
def get_scheduler(optimizer: Optimizer) -> LRScheduler:
    return StepLR(optimizer, ...)

gw = GlobalWorkspace(
    ...,
    scheduler=get_scheduler,
)
```
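As an illustration (not part of the original commit message; the `StepLR` hyperparameters are made up, the `...` placeholders stand for the other constructor arguments as in the example above, and `GlobalWorkspace` is assumed to be importable from the changed module), the same callback can be built inline with `functools.partial`, while passing `None` explicitly disables the scheduler and leaving the argument at its default keeps the `OneCycleLR` behaviour:

```python
from functools import partial

from torch.optim.lr_scheduler import StepLR

from shimmer.modules.global_workspace import GlobalWorkspace

# partial(StepLR, ...) is a callable that receives the optimizer and returns
# a configured StepLR, so it satisfies Callable[[Optimizer], LRScheduler].
gw = GlobalWorkspace(
    ...,
    scheduler=partial(StepLR, step_size=10, gamma=0.5),
)

# Passing None explicitly disables the scheduler entirely.
gw_without_scheduler = GlobalWorkspace(..., scheduler=None)
```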
bdvllrs authored Oct 11, 2024
1 parent 9b50160 commit 5369d9e
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions shimmer/modules/global_workspace.py
@@ -9,6 +9,7 @@
from torch.nn import Module, ModuleDict
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
+ from torch.optim.optimizer import Optimizer

from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
from shimmer.modules.domain import DomainModule, LossOutput
@@ -230,7 +231,7 @@ def __init__(
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
- scheduler: LRScheduler
+ scheduler: Callable[[Optimizer], LRScheduler]
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
@@ -245,7 +246,8 @@ def __init__(
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
scheduler parameters.
- scheduler: scheduler to use. If None is explicitely given, no scheduler
+ scheduler (`Callable[[Optimizer], LRScheduler]`): Callback that returns the
+     scheduler to use. If None is explicitely given, no scheduler
will be used. By default, uses OneCycleScheduler
"""
super().__init__()
@@ -604,7 +606,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
if isinstance(self.scheduler, OneCycleSchedulerSentinel):
lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
else:
- lr_scheduler = self.scheduler
+ lr_scheduler = self.scheduler(optimizer)

return {
"optimizer": optimizer,
@@ -661,7 +663,7 @@ def __init__(
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
contrastive_loss: ContrastiveLossType | None = None,
- scheduler: LRScheduler
+ scheduler: Callable[[Optimizer], LRScheduler]
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
@@ -739,7 +741,7 @@ def __init__(
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
contrastive_loss: ContrastiveLossType | None = None,
- scheduler: LRScheduler
+ scheduler: Callable[[Optimizer], LRScheduler]
| None
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
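
The default is a dedicated sentinel (`OneCycleSchedulerSentinel.DEFAULT`) rather than `None`, because `None` already carries meaning: it disables the scheduler. Below is a minimal, self-contained sketch of the resulting three-way dispatch in `configure_optimizers`. It uses stand-in code rather than the actual shimmer implementation, and the optimizer and scheduler hyperparameters are arbitrary:

```python
from collections.abc import Callable
from enum import Enum, auto

from torch.nn import Linear
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR, StepLR
from torch.optim.optimizer import Optimizer


class OneCycleSchedulerSentinel(Enum):
    # Stand-in sentinel: distinguishes "argument not passed" from an explicit None.
    DEFAULT = auto()


def build_scheduler(
    optimizer: Optimizer,
    scheduler: Callable[[Optimizer], LRScheduler]
    | None
    | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> LRScheduler | None:
    if isinstance(scheduler, OneCycleSchedulerSentinel):
        # Default: keep the previous OneCycleLR behaviour.
        return OneCycleLR(optimizer, max_lr=1e-3, total_steps=100)
    if scheduler is None:
        # Explicit None: no scheduler at all.
        return None
    # Otherwise the callback receives the optimizer and builds the scheduler.
    return scheduler(optimizer)


optimizer = AdamW(Linear(4, 4).parameters(), lr=1e-3)
assert isinstance(build_scheduler(optimizer), OneCycleLR)
assert build_scheduler(optimizer, scheduler=None) is None
assert isinstance(
    build_scheduler(optimizer, lambda opt: StepLR(opt, step_size=10)), StepLR
)
```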
