diff --git a/shimmer/__init__.py b/shimmer/__init__.py
index 5e8e162b..c3197ac9 100644
--- a/shimmer/__init__.py
+++ b/shimmer/__init__.py
@@ -10,7 +10,7 @@
     ContrastiveLossType,
     contrastive_loss,
 )
-from shimmer.modules.domain import DomainModule, End2EndDomainModule, LossOutput
+from shimmer.modules.domain import DomainModule, LossOutput
 from shimmer.modules.global_workspace import (
     GlobalWorkspace2Domains,
     GlobalWorkspaceBase,
@@ -79,7 +79,6 @@
     "pretrained_global_workspace",
     "LossOutput",
     "DomainModule",
-    "End2EndDomainModule",
     "GWDecoder",
     "GWEncoder",
     "GWEncoderLinear",
diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py
index ee5e1037..7d3a18ec 100644
--- a/shimmer/modules/__init__.py
+++ b/shimmer/modules/__init__.py
@@ -5,7 +5,7 @@
     ContrastiveLossType,
     contrastive_loss,
 )
-from shimmer.modules.domain import DomainModule, End2EndDomainModule, LossOutput
+from shimmer.modules.domain import DomainModule, LossOutput
 from shimmer.modules.global_workspace import (
     GlobalWorkspace2Domains,
     GlobalWorkspaceBase,
@@ -59,7 +59,6 @@
     "pretrained_global_workspace",
     "LossOutput",
     "DomainModule",
-    "End2EndDomainModule",
     "GWDecoder",
     "GWEncoder",
     "GWEncoderLinear",
diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py
index 24baa1ec..c909f010 100644
--- a/shimmer/modules/domain.py
+++ b/shimmer/modules/domain.py
@@ -71,6 +71,24 @@ def __init__(
         self.latent_dim = latent_dim
         """The latent dimension of the module."""
 
+        self.is_frozen: bool | None = None
+        """Whether the module is frozen. If None, it is frozen by default."""
+
+    def freeze(self) -> None:
+        """
+        Freezes the module. This is the default mode.
+        """
+        self.is_frozen = True
+        return super().freeze()
+
+    def unfreeze(self) -> None:
+        """
+        Unfreezes the module. This is useful for training the domain module
+        end-to-end. It also unlocks `compute_domain_loss` during training.
+        """
+        self.is_frozen = False
+        return super().unfreeze()
+
     def encode(self, x: Any) -> torch.Tensor:
         """
         Encode the domain data into a unimodal representation.
@@ -186,8 +204,6 @@ def compute_fused_loss(
         """
         return self.compute_loss(pred, target, raw_target)
 
-
-class End2EndDomainModule(DomainModule):
     def compute_domain_loss(self, domain: Any) -> LossOutput | None:
         """
         Compute the unimodal domain loss.
diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index 13aa9584..ee03690a 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -11,7 +11,7 @@
 from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
 
 from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
-from shimmer.modules.domain import DomainModule, End2EndDomainModule, LossOutput
+from shimmer.modules.domain import DomainModule, LossOutput
 from shimmer.modules.gw_module import (
     GWModule,
     GWModuleBase,
@@ -492,7 +492,7 @@ def unimodal_losses(self, batch: RawDomainGroupsT) -> LossOutput | None:
                 continue
             for domain_name, domain in domain_group.items():
                 domain_mod = self.domain_mods[domain_name]
-                if isinstance(domain_mod, End2EndDomainModule):
+                if not domain_mod.is_frozen:
                     loss = domain_mod.compute_domain_loss(domain)
                     if loss is not None:
                         for name, metric in loss.metrics.items():
@@ -628,10 +628,6 @@ def freeze_domain_modules(
     The output is casted as `dict[str, DomainModule]` type for better
     auto-completion, but is actually a torch `ModuleDict`.
 
-    .. note::
-        Instances of `End2EndDomainModule` are not frozen as they should be trained
-        alongside the GW.
-
     Args:
         domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
@@ -640,7 +636,7 @@
     """
     for mod in domain_mods.values():
-        if not isinstance(mod, End2EndDomainModule):
+        if mod.is_frozen is None:
             mod.freeze()
     # Cast for better auto-completion at the expense of ModuleDict
     return cast(dict[str, DomainModule], ModuleDict(domain_mods))
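With `End2EndDomainModule` gone, end-to-end training becomes an opt-in property of any `DomainModule`: override `compute_domain_loss` and call `unfreeze()` before handing the module to the GW. Below is a minimal sketch of the new usage; the `ImageDomain` module and its layer sizes are hypothetical and not part of this PR:

```python
import torch
import torch.nn.functional as F
from torch import nn

from shimmer import DomainModule, LossOutput


class ImageDomain(DomainModule):
    """Hypothetical domain module meant to be trained alongside the GW."""

    def __init__(self, latent_dim: int = 32) -> None:
        super().__init__(latent_dim)
        self.encoder = nn.Linear(784, latent_dim)  # made-up sizes
        self.decoder = nn.Linear(latent_dim, 784)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

    def compute_domain_loss(self, domain: torch.Tensor) -> LossOutput:
        # Unimodal loss (here a simple reconstruction loss), now available
        # on any unfrozen DomainModule rather than a dedicated subclass.
        reconstruction = self.decode(self.encode(domain))
        return LossOutput(F.mse_loss(reconstruction, domain))


domain = ImageDomain()
domain.unfreeze()  # is_frozen = False: freeze_domain_modules() will skip it
```

The tri-state `is_frozen` keeps the old default intact: modules that never call `freeze()` or `unfreeze()` stay at `None` and are frozen by `freeze_domain_modules`, while an explicit `unfreeze()` survives GW construction and enables `unimodal_losses` for that domain.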