Skip to content

Commit

Permalink
make sentinel public
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Sep 18, 2024
1 parent 77c6d9d commit 64cdc30
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def batch_broadcasts(
return predictions, cycles


class _OneCycleSchedulerSentinel(Enum):
class OneCycleSchedulerSentinel(Enum):
"""
Used for backward-compatibility issues to use One-Cycle Scheduler by default
"""
Expand All @@ -234,7 +234,7 @@ def __init__(
scheduler_args: SchedulerArgs | None = None,
scheduler: LRScheduler
| None
| _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT,
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
"""
Initializes a GW
Expand Down Expand Up @@ -571,7 +571,7 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
return {"optimizer": optimizer}

lr_scheduler: LRScheduler
if isinstance(self.scheduler, _OneCycleSchedulerSentinel):
if isinstance(self.scheduler, OneCycleSchedulerSentinel):
lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
else:
lr_scheduler = self.scheduler
Expand Down Expand Up @@ -632,7 +632,7 @@ def __init__(
contrastive_loss: ContrastiveLossType | None = None,
scheduler: LRScheduler
| None
| _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT,
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
"""
Initializes a Global Workspace
Expand Down Expand Up @@ -705,7 +705,7 @@ def __init__(
contrastive_loss: ContrastiveLossType | None = None,
scheduler: LRScheduler
| None
| _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT,
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
"""
Initializes a Global Workspace
Expand Down Expand Up @@ -788,7 +788,7 @@ def __init__(
precision_softmax_temp: float = 0.01,
scheduler: LRScheduler
| None
| _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT,
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
) -> None:
"""
Initializes a Global Workspace
Expand Down Expand Up @@ -870,7 +870,7 @@ def pretrained_global_workspace(
contrastive_fn: ContrastiveLossType,
scheduler: LRScheduler
| None
| _OneCycleSchedulerSentinel = _OneCycleSchedulerSentinel.DEFAULT,
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
**kwargs,
) -> GlobalWorkspace2Domains:
"""
Expand Down

0 comments on commit 64cdc30

Please sign in to comment.