diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 534af8b0..f13abdb8 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -207,7 +207,7 @@ def batch_broadcasts( return predictions, cycles -class _OneCycleSchedulerSentinel(Enum): +class OneCycleSchedulerSentinel(Enum): """ Used for backward-compatibility issues to use One-Cycle Scheduler by default """ @@ -234,7 +234,7 @@ def __init__( scheduler_args: SchedulerArgs | None = None, scheduler: LRScheduler | None - | _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT, + | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: """ Initializes a GW @@ -571,7 +571,7 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig: return {"optimizer": optimizer} lr_scheduler: LRScheduler - if isinstance(self.scheduler, _OneCycleSchedulerSentinel): + if isinstance(self.scheduler, OneCycleSchedulerSentinel): lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args) else: lr_scheduler = self.scheduler @@ -632,7 +632,7 @@ def __init__( contrastive_loss: ContrastiveLossType | None = None, scheduler: LRScheduler | None - | _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT, + | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: """ Initializes a Global Workspace @@ -705,7 +705,7 @@ def __init__( contrastive_loss: ContrastiveLossType | None = None, scheduler: LRScheduler | None - | _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT, + | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: """ Initializes a Global Workspace @@ -788,7 +788,7 @@ def __init__( precision_softmax_temp: float = 0.01, scheduler: LRScheduler | None - | _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT, + | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, ) -> None: """ Initializes a Global Workspace @@ -870,7 +870,7 @@ def pretrained_global_workspace( contrastive_fn: ContrastiveLossType, scheduler: LRScheduler | None - | _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT, + | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, **kwargs, ) -> GlobalWorkspace2Domains: """