From fd1d65c59b1d2fc9b29c05d1bee806297e1f3b0e Mon Sep 17 00:00:00 2001
From: Hugo Chateau-Laurent
Date: Tue, 8 Oct 2024 16:12:31 +0200
Subject: [PATCH] Remove GWEncoderLinear (#163)

Remove the GWEncoderLinear class because it is not used. Since
GWModule.fuse previously hard-coded torch.tanh, this change also adds a
configurable fusion_activation_fn argument (defaulting to torch.tanh)
to GWModule, the GlobalWorkspace constructors, and
pretrained_global_workspace. The stale GWEncoder docstring, which still
described a trailing tanh, is trimmed accordingly.
---
 shimmer/__init__.py                 |  2 --
 shimmer/modules/__init__.py         |  2 --
 shimmer/modules/global_workspace.py | 23 +++++++++++++++++++----
 shimmer/modules/gw_module.py        | 23 +++++++++--------------
 4 files changed, 28 insertions(+), 22 deletions(-)

diff --git a/shimmer/__init__.py b/shimmer/__init__.py
index 090c0b1c..db93b88e 100644
--- a/shimmer/__init__.py
+++ b/shimmer/__init__.py
@@ -24,7 +24,6 @@
 from shimmer.modules.gw_module import (
     GWDecoder,
     GWEncoder,
-    GWEncoderLinear,
     GWModule,
     GWModuleBase,
     GWModulePrediction,
@@ -77,7 +76,6 @@
     "DomainModule",
     "GWDecoder",
     "GWEncoder",
-    "GWEncoderLinear",
     "GWModuleBase",
     "GWModule",
     "GWModulePrediction",
diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py
index edffb919..90f7171b 100644
--- a/shimmer/modules/__init__.py
+++ b/shimmer/modules/__init__.py
@@ -18,7 +18,6 @@
 from shimmer.modules.gw_module import (
     GWDecoder,
     GWEncoder,
-    GWEncoderLinear,
     GWModule,
     GWModuleBase,
     GWModulePrediction,
@@ -56,7 +55,6 @@
     "DomainModule",
     "GWDecoder",
     "GWEncoder",
-    "GWEncoderLinear",
     "GWModuleBase",
     "GWModule",
     "GWModulePrediction",
diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index a89ee0dc..24d18fab 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -1,4 +1,4 @@
-from collections.abc import Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping
 from enum import Enum, auto
 from pathlib import Path
 from typing import Any, Generic, TypedDict, TypeVar, cast
@@ -663,6 +663,7 @@ def __init__(
         scheduler: LRScheduler
         | None
         | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
+        fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -689,10 +690,14 @@ def __init__(
                 contrastive losses.
             scheduler: The scheduler to use for traning. If None is explicitely
                 given, no scheduler will be used. Defaults to use OneCycleScheduler
+            fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+                function to fuse the domains.
         """
         domain_mods = freeze_domain_modules(domain_mods)
-        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+        gw_mod = GWModule(
+            domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+        )
         if contrastive_loss is None:
             contrastive_loss = ContrastiveLoss(
                 torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
             )
@@ -736,6 +741,7 @@ def __init__(
         scheduler: LRScheduler
         | None
         | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
+        fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -764,9 +770,13 @@ def __init__(
                 contrastive losses.
             scheduler: The scheduler to use for traning. If None is explicitely
                 given, no scheduler will be used. Defaults to use OneCycleScheduler
+            fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+                function to fuse the domains.
         """
         domain_mods = freeze_domain_modules(domain_mods)
-        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+        gw_mod = GWModule(
+            domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+        )
 
         if contrastive_loss is None:
             contrastive_loss = ContrastiveLoss(
@@ -800,6 +810,7 @@ def pretrained_global_workspace(
     scheduler: LRScheduler
     | None
     | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
+    fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
     **kwargs,
 ) -> GlobalWorkspace2Domains:
     """
@@ -823,6 +834,8 @@ def pretrained_global_workspace(
             contrastive losses.
         scheduler: The scheduler to use for traning. If None is explicitely
             given, no scheduler will be used. Defaults to use OneCycleScheduler
+        fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+            function to fuse the domains.
         **kwargs: additional arguments to pass to
             `GlobalWorkspace.load_from_checkpoint`.
 
@@ -833,7 +846,9 @@
         `TypeError`: if loaded type is not `GlobalWorkspace`.
     """
     domain_mods = freeze_domain_modules(domain_mods)
-    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+    gw_mod = GWModule(
+        domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+    )
     selection_mod = SingleDomainSelection()
     loss_mod = GWLosses2Domains(
         gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py
index f7dedb5f..90b9abcc 100644
--- a/shimmer/modules/gw_module.py
+++ b/shimmer/modules/gw_module.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping
 from typing import TypedDict
 
 import torch
@@ -180,11 +180,7 @@ def __init__(
 
 
 class GWEncoder(GWDecoder):
-    """
-    An Encoder network used in GWModules.
-
-    This is similar to the decoder, but adds a tanh non-linearity at the end.
-    """
+    """An Encoder network used in GWModules."""
 
     def __init__(
         self,
@@ -209,13 +205,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return super().forward(input)
 
 
-class GWEncoderLinear(nn.Linear):
-    """A linear Encoder network used in GWModules."""
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return torch.tanh(super().forward(input))
-
-
 class GWModulePrediction(TypedDict):
     """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
 
@@ -368,6 +357,7 @@ def __init__(
         workspace_dim: int,
         gw_encoders: Mapping[str, nn.Module],
         gw_decoders: Mapping[str, nn.Module],
+        fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
     ) -> None:
         """
         Initializes the GWModule.
@@ -381,6 +371,8 @@ def __init__(
             gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
                 name to a an torch.nn.Module class that decodes a
                 GW representation to a unimodal latent representation.
+            fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
+                function used to fuse the domains.
         """
         super().__init__(domain_modules, workspace_dim)
 
@@ -390,6 +382,9 @@ def __init__(
         self.gw_decoders = nn.ModuleDict(gw_decoders)
         """The module's decoders"""
 
+        self.fusion_activation_fn = fusion_activation_fn
+        """Activation function used to fuse the domains."""
+
     def fuse(
         self,
         x: LatentsDomainGroupT,
@@ -405,7 +400,7 @@
         Returns:
             `torch.Tensor`: The merged representation.
         """
-        return torch.tanh(
+        return self.fusion_activation_fn(
             torch.sum(
                 torch.stack(
                     [