diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index a9e042e0..8c5aa3fe 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -545,13 +545,10 @@ def forward(  # type: ignore
         )
 
 
-class GlobalWorkspaceWithUncertainty(
-    GlobalWorkspaceBase[
-        GWModuleWithUncertainty, SingleDomainSelection, GWLossesWithUncertainty
-    ]
+class GlobalWorkspaceFusion(
+    GlobalWorkspaceBase[GWModule, RandomSelection, GWLossesFusion]
 ):
-    """
-    A simple 2-domains max GlobalWorkspaceBase with uncertainty.
+    """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
 
     This is used to simplify a Global Workspace instanciation and only overrides the
     `__init__` method.
@@ -563,7 +560,8 @@ def __init__(
         gw_encoders: Mapping[str, Module],
         gw_decoders: Mapping[str, Module],
         workspace_dim: int,
-        loss_coefs: LossCoefs,
+        loss_coefs: BroadcastLossCoefs,
+        selection_temperature: float = 0.2,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
@@ -584,7 +582,9 @@ def __init__(
                 name to a `torch.nn.Module` class which role is to decode a
                 GW representation into a unimodal latent representations.
             workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`LossCoefs`): loss coefficients
+            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+            selection_temperature (`float`): temperature value for the RandomSelection
+                module.
             optim_lr (`float`): learning rate
             optim_weight_decay (`float`): weight decay
             scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
@@ -595,23 +595,16 @@ def __init__(
                 contrastive losses.
         """
         domain_mods = freeze_domain_modules(domain_mods)
+        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
 
-        gw_mod = GWModuleWithUncertainty(
-            domain_mods, workspace_dim, gw_encoders, gw_decoders
-        )
-
-        selection_mod = SingleDomainSelection()
-
-        contrastive_loss = ContrastiveLoss(
-            torch.tensor([1]).log(), "mean", learn_logit_scale
-        )
+        if contrastive_loss is None:
+            contrastive_loss = ContrastiveLoss(
+                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+            )
 
-        loss_mod = GWLossesWithUncertainty(
-            gw_mod,
-            selection_mod,
-            domain_mods,
-            loss_coefs,
-            contrastive_loss,
+        selection_mod = RandomSelection(selection_temperature)
+        loss_mod = GWLossesFusion(
+            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
         )
 
         super().__init__(
@@ -623,37 +616,14 @@ def __init__(
             scheduler_args,
         )
 
-    def forward(  # type: ignore
-        self,
-        latent_domains: LatentsDomainGroupsT,
-    ) -> GWPredictions:
-        """
-        Computes demi-cycles, cycles, and translations.
-
-        Args:
-            latent_domains (`LatentsT`): Groups of domains for the computation.
-
-        Returns:
-            `GWPredictions`: the predictions on the batch.
-        """
-        return GWPredictions(
-            demi_cycles=batch_demi_cycles(
-                self.gw_mod, self.selection_mod, latent_domains
-            ),
-            cycles=batch_cycles(
-                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
-            ),
-            translations=batch_translations(
-                self.gw_mod, self.selection_mod, latent_domains
-            ),
-            **super().forward(latent_domains),
-        )
-
 
-class GlobalWorkspaceFusion(
-    GlobalWorkspaceBase[GWModule, RandomSelection, GWLossesFusion]
+class GlobalWorkspaceWithUncertainty(
+    GlobalWorkspaceBase[
+        GWModuleWithUncertainty, RandomSelection, GWLossesWithUncertainty
+    ]
 ):
-    """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
+    """
+    A simple 2-domains max GlobalWorkspaceBase with uncertainty.
 
     This is used to simplify a Global Workspace instanciation and only overrides the
     `__init__` method.
@@ -665,7 +635,7 @@ def __init__(
         gw_encoders: Mapping[str, Module],
         gw_decoders: Mapping[str, Module],
         workspace_dim: int,
-        loss_coefs: BroadcastLossCoefs,
+        loss_coefs: LossCoefs,
         selection_temperature: float = 0.2,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
@@ -687,9 +657,8 @@ def __init__(
                 name to a `torch.nn.Module` class which role is to decode a
                 GW representation into a unimodal latent representations.
             workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
-            selection_temperature (`float`): temperature value for the RandomSelection
-                module.
+            loss_coefs (`LossCoefs`): loss coefficients
+            selection_temperature (`float`): temperature for `RandomSelection`
             optim_lr (`float`): learning rate
             optim_weight_decay (`float`): weight decay
             scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
@@ -700,16 +669,23 @@ def __init__(
                 contrastive losses.
         """
         domain_mods = freeze_domain_modules(domain_mods)
-        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
 
-        if contrastive_loss is None:
-            contrastive_loss = ContrastiveLoss(
-                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
-            )
+        gw_mod = GWModuleWithUncertainty(
+            domain_mods, workspace_dim, gw_encoders, gw_decoders
+        )
 
         selection_mod = RandomSelection(selection_temperature)
-        loss_mod = GWLossesFusion(
-            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
+
+        contrastive_loss = ContrastiveLoss(
+            torch.tensor([1]).log(), "mean", learn_logit_scale
+        )
+
+        loss_mod = GWLossesWithUncertainty(
+            gw_mod,
+            selection_mod,
+            domain_mods,
+            loss_coefs,
+            contrastive_loss,
         )
 
         super().__init__(
@@ -721,6 +697,32 @@ def __init__(
             scheduler_args,
         )
 
+    def forward(  # type: ignore
+        self,
+        latent_domains: LatentsDomainGroupsT,
+    ) -> GWPredictions:
+        """
+        Computes demi-cycles, cycles, and translations.
+
+        Args:
+            latent_domains (`LatentsT`): Groups of domains for the computation.
+
+        Returns:
+            `GWPredictions`: the predictions on the batch.
+        """
+        return GWPredictions(
+            demi_cycles=batch_demi_cycles(
+                self.gw_mod, self.selection_mod, latent_domains
+            ),
+            cycles=batch_cycles(
+                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+            ),
+            translations=batch_translations(
+                self.gw_mod, self.selection_mod, latent_domains
+            ),
+            **super().forward(latent_domains),
+        )
+
 
 def pretrained_global_workspace(
     checkpoint_path: str | Path,