Skip to content

Commit

Permalink
Use generics for GlobalWorkspaceBase (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Apr 11, 2024
1 parent d1d965f commit ee16d5c
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterable, Mapping
from pathlib import Path
from typing import Any, TypedDict, cast
from typing import Any, Generic, TypedDict, TypeVar, cast

import torch
from lightning.pytorch import LightningModule
Expand Down Expand Up @@ -61,7 +61,14 @@ class GWPredictionsBase(TypedDict):
"""


class GlobalWorkspaceBase(LightningModule):
_T_gw_mod = TypeVar("_T_gw_mod", bound=GWModuleBase)
_T_selection_mod = TypeVar("_T_selection_mod", bound=SelectionBase)
_T_loss_mod = TypeVar("_T_loss_mod", bound=GWLossesBase)


class GlobalWorkspaceBase(
Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
):
"""
Global Workspace Lightning Module.
Expand All @@ -70,9 +77,9 @@ class GlobalWorkspaceBase(LightningModule):

def __init__(
self,
gw_mod: GWModuleBase,
selection_mod: SelectionBase,
loss_mod: GWLossesBase,
gw_mod: _T_gw_mod,
selection_mod: _T_selection_mod,
loss_mod: _T_loss_mod,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
Expand Down Expand Up @@ -445,7 +452,7 @@ class GWPredictions(GWPredictionsBase):
"""


class GlobalWorkspace(GlobalWorkspaceBase):
class GlobalWorkspace(GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses]):
"""
A simple 2-domains max flavor of GlobalWorkspaceBase.
Expand Down Expand Up @@ -538,17 +545,18 @@ def forward( # type: ignore
)


class GlobalWorkspaceWithUncertainty(GlobalWorkspaceBase):
class GlobalWorkspaceWithUncertainty(
GlobalWorkspaceBase[
GWModuleWithUncertainty, SingleDomainSelection, GWLossesWithUncertainty
]
):
"""
A simple 2-domains max GlobalWorkspaceBase with uncertainty.
This is used to simplify a Global Workspace instanciation and only overrides the
`__init__` method.
"""

gw_mod: GWModuleWithUncertainty
selection_mod: SingleDomainSelection

def __init__(
self,
domain_mods: Mapping[str, DomainModule],
Expand Down Expand Up @@ -642,7 +650,9 @@ def forward( # type: ignore
)


class GlobalWorkspaceFusion(GlobalWorkspaceBase):
class GlobalWorkspaceFusion(
GlobalWorkspaceBase[GWModule, RandomSelection, GWLossesFusion]
):
"""The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
This is used to simplify a Global Workspace instanciation and only overrides the
Expand Down

0 comments on commit ee16d5c

Please sign in to comment.