diff --git a/CHANGELOG.md b/CHANGELOG.md index 92331f8c..bbcccb0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,3 +36,8 @@ refers to `DeterministicGlobalWorkspace`. * Rename every abstract class with ClassNameBase. Rename every "Deterministic" classes to remove "Deterministic". * Remove all config related functions. This is not the role of this repo. + +# 0.4.1 +* Remove `GWInterfaces` entirely and favor giving encoders and decoders directly to the + `GWModule`. See the updated example `examples/main_example/train_gw.py` to see what + changes to make. diff --git a/docs/assets/shimmer_architecture.png b/docs/assets/shimmer_architecture.png index ab7efe5e..bcc9ec80 100755 Binary files a/docs/assets/shimmer_architecture.png and b/docs/assets/shimmer_architecture.png differ diff --git a/docs/shimmer_basics.md b/docs/shimmer_basics.md index 3929e2c3..9d35b175 100644 --- a/docs/shimmer_basics.md +++ b/docs/shimmer_basics.md @@ -15,9 +15,7 @@ to make a GW in shimmer: Let's detail: - [`DomainModule`](https://bdvllrs.github.io/shimmer/shimmer.html#DomainModule)s are the individual domain modules which encode domain data into a latent vector; -- `GWInterface`s are links to encode one domain in a GW representation; -- the `GWModule` has access to all `GWInterface`s and defines how to encode, decode and -merge representations of the domains into a unique GW representation. +- the `GWModule` has access to the domain modules, and defines how to encode, decode and merge representations of the domains into a unique GW representation. 
- finally `GlobalWorkspaceBase` takes all building blocks to make a [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) module The last building block (not in the diagram) is the `GWLosses` class which @@ -418,8 +416,9 @@ from dataset import GWDataModule, get_domain_data, make_datasets from domains import GenericDomain from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint +from torch import nn -from shimmer import GlobalWorkspace, GWInterface, LossCoefs +from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs def train_gw(): @@ -451,20 +450,24 @@ def train_gw(): workspace_dim = 16 - # Now we define interfaces that will encode and decode the domain representations - # to and from the global workspace - # We will use the already defined GWInterface class - gw_interfaces: dict[str, GWInterface] = {} + # Now we define modality encoders and decoders that will encode and decode + # the domain representations to and from the global workspace + gw_encoders: dict[str, nn.Module] = {} + gw_decoders: dict[str, nn.Module] = {} for name, mod in domain_mods.items(): - gw_interfaces[name] = GWInterface( - mod, - workspace_dim, - encoder_hidden_dim=64, + gw_encoders[name] = GWEncoder( + mod.latent_dim, + hidden_dim=64, + out_dim=workspace_dim, # total number of Linear layers is this value + 2 (one before, one after) - encoder_n_layers=1, - decoder_hidden_dim=64, + n_layers=1, + ) + gw_decoders[name] = GWDecoder( + in_dim=workspace_dim, + hidden_dim=64, + out_dim=mod.latent_dim, # total number of Linear layers is this value + 2 (one before, one after) - decoder_n_layers=1, + n_layers=1, ) loss_coefs: LossCoefs = { @@ -475,7 +478,7 @@ def train_gw(): } global_workspace = GlobalWorkspace( - domain_mods, gw_interfaces, workspace_dim, loss_coefs + domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs ) trainer = Trainer( @@ -535,24 +538,28 @@ This should be the same as what was used for the data. 
} ``` -We create the `GWInterfaces` to link the domain modules with the GlobalWorkspace +We define encoders and decoders to link the domain modules with the GlobalWorkspace ```python workspace_dim = 16 - # Now we define interfaces that will encode and decode the domain representations - # to and from the global workspace - # We will use the already defined GWInterface class - gw_interfaces: dict[str, GWInterface] = {} + # Now we define modality encoders and decoders that will encode and decode + # the domain representations to and from the global workspace + gw_encoders: dict[str, nn.Module] = {} + gw_decoders: dict[str, nn.Module] = {} for name, mod in domain_mods.items(): - gw_interfaces[name] = GWInterface( - mod, - workspace_dim, - encoder_hidden_dim=64, + gw_encoders[name] = GWEncoder( + mod.latent_dim, + hidden_dim=64, + out_dim=workspace_dim, # total number of Linear layers is this value + 2 (one before, one after) - encoder_n_layers=1, - decoder_hidden_dim=64, + n_layers=1, + ) + gw_decoders[name] = GWDecoder( + in_dim=workspace_dim, + hidden_dim=64, + out_dim=mod.latent_dim, # total number of Linear layers is this value + 2 (one before, one after) - decoder_n_layers=1, + n_layers=1, ) ``` @@ -570,7 +577,7 @@ We define loss coefficients for the different losses. Note that `LossCoefs` is a Finally we make the GlobalWorkspace and train it. 
```python global_workspace = GlobalWorkspace( - domain_mods, gw_interfaces, workspace_dim, loss_coefs + domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs ) trainer = Trainer( diff --git a/examples/main_example/train_gw.py b/examples/main_example/train_gw.py index 516a6b9f..76d8eaca 100644 --- a/examples/main_example/train_gw.py +++ b/examples/main_example/train_gw.py @@ -2,8 +2,9 @@ from domains import GenericDomain from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint +from torch import nn -from shimmer import GlobalWorkspace, GWInterface, LossCoefs +from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs def train_gw(): @@ -35,20 +36,24 @@ def train_gw(): workspace_dim = 16 - # Now we define interfaces that will encode and decode the domain representations - # to and from the global workspace - # We will use the already defined GWInterface class - gw_interfaces: dict[str, GWInterface] = {} + # Now we define modality encoders and decoders that will encode and decode + # the domain representations to and from the global workspace + gw_encoders: dict[str, nn.Module] = {} + gw_decoders: dict[str, nn.Module] = {} for name, mod in domain_mods.items(): - gw_interfaces[name] = GWInterface( - mod, - workspace_dim, - encoder_hidden_dim=64, + gw_encoders[name] = GWEncoder( + mod.latent_dim, + hidden_dim=64, + out_dim=workspace_dim, # total number of Linear layers is this value + 2 (one before, one after) - encoder_n_layers=1, - decoder_hidden_dim=64, + n_layers=1, + ) + gw_decoders[name] = GWDecoder( + in_dim=workspace_dim, + hidden_dim=64, + out_dim=mod.latent_dim, # total number of Linear layers is this value + 2 (one before, one after) - decoder_n_layers=1, + n_layers=1, ) loss_coefs: LossCoefs = { @@ -59,7 +64,7 @@ def train_gw(): } global_workspace = GlobalWorkspace( - domain_mods, gw_interfaces, workspace_dim, loss_coefs + domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs ) trainer = 
Trainer( diff --git a/shimmer/__init__.py b/shimmer/__init__.py index dcd79e93..124065a2 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -19,12 +19,10 @@ from shimmer.modules.gw_module import ( GWDecoder, GWEncoder, - GWInterface, - GWInterfaceBase, + GWEncoderLinear, GWModule, GWModuleBase, VariationalGWEncoder, - VariationalGWInterface, VariationalGWModule, ) from shimmer.modules.losses import ( @@ -68,10 +66,8 @@ "DomainModule", "GWDecoder", "GWEncoder", + "GWEncoderLinear", "VariationalGWEncoder", - "GWInterfaceBase", - "GWInterface", - "VariationalGWInterface", "GWModuleBase", "GWModule", "VariationalGWModule", diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index b7b34b5c..bcd0a6af 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -19,12 +19,10 @@ from shimmer.modules.gw_module import ( GWDecoder, GWEncoder, - GWInterface, - GWInterfaceBase, + GWEncoderLinear, GWModule, GWModuleBase, VariationalGWEncoder, - VariationalGWInterface, VariationalGWModule, ) from shimmer.modules.losses import ( @@ -54,10 +52,8 @@ "DomainModule", "GWDecoder", "GWEncoder", + "GWEncoderLinear", "VariationalGWEncoder", - "GWInterfaceBase", - "GWInterface", - "VariationalGWInterface", "GWModuleBase", "GWModule", "VariationalGWModule", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 8a8e702e..138d8c1f 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -5,7 +5,7 @@ import torch from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig -from torch.nn import ModuleDict +from torch.nn import Module, ModuleDict from torch.optim.lr_scheduler import OneCycleLR from shimmer.modules.contrastive_loss import ( @@ -16,7 +16,6 @@ ) from shimmer.modules.domain import DomainModule from shimmer.modules.gw_module import ( - GWInterfaceBase, GWModule, GWModuleBase, GWModuleFusion, @@ -87,7 
+86,6 @@ class GlobalWorkspaceBase(LightningModule): def __init__( self, gw_mod: GWModuleBase, - domain_mods: Mapping[str, DomainModule], loss_mod: GWLossesBase, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, @@ -97,9 +95,6 @@ def __init__( Args: gw_mod (`GWModuleBase`): the GWModule - domain_mods (`Mapping[str, DomainModule]`): mapping of the domains - connected to the GW. Keys are domain names, values are the - `DomainModule`. loss_mod (`GWLossesBase`): module to compute the GW losses. optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay @@ -114,14 +109,14 @@ def __init__( "loss_mod", "domain_descriptions", "contrastive_loss", - "gw_interfaces", + "gw_encoders", + "gw_decoders", ] ) self.gw_mod = gw_mod """ a `GWModuleBase` implementation.""" - self.domain_mods = domain_mods - """Mapping of `DomainModule`s.""" + self.loss_mod = loss_mod """The module that computes losses of the GW""" @@ -131,6 +126,10 @@ def __init__( if scheduler_args is not None: self.scheduler_args.update(scheduler_args) + @property + def domain_mods(self) -> Mapping[str, DomainModule]: + return self.gw_mod.domain_mods + @property def workspace_dim(self) -> int: """Dimension of the GW.""" @@ -511,7 +510,8 @@ class GlobalWorkspace(GlobalWorkspaceBase): def __init__( self, domain_mods: Mapping[str, DomainModule], - gw_interfaces: Mapping[str, GWInterfaceBase], + gw_encoders: Mapping[str, Module], + gw_decoders: Mapping[str, Module], workspace_dim: int, loss_coefs: LossCoefs, optim_lr: float = 1e-3, @@ -526,9 +526,12 @@ def __init__( domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`. 
- gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain - name to a `GWInterfaceBase` class which role is to encode/decode + gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to encode unimodal latent representations into a GW representation (pre fusion). + gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to decode a + GW representation into a unimodal latent representation. workspace_dim (`int`): dimension of the GW. loss_coefs (`LossCoefs`): loss coefficients optim_lr (`float`): learning rate @@ -540,8 +543,14 @@ def __init__( function used for alignment. `learn_logit_scale` will not affect custom contrastive losses. """ - gw_mod = GWModule(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) + + gw_mod = GWModule( + domain_mods, + workspace_dim, + gw_encoders, + gw_decoders, + ) if contrastive_loss is None: contrastive_loss = ContrastiveLoss( torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale @@ -555,7 +564,6 @@ def __init__( super().__init__( gw_mod, - domain_mods, loss_mod, optim_lr, optim_weight_decay, @@ -573,7 +581,8 @@ class VariationalGlobalWorkspace(GlobalWorkspaceBase): def __init__( self, domain_mods: Mapping[str, DomainModule], - gw_interfaces: Mapping[str, GWInterfaceBase], + gw_encoders: Mapping[str, Module], + gw_decoders: Mapping[str, Module], workspace_dim: int, loss_coefs: VariationalLossCoefs, use_var_contrastive_loss: bool = False, @@ -590,9 +599,12 @@ def __init__( domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`. 
- gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain - name to a `GWInterfaceBase` class which role is to encode/decode + gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to encode unimodal latent representations into a GW representation (pre fusion). + gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to decode a + GW representation into a unimodal latent representation. workspace_dim (`int`): dimension of the GW. loss_coefs (`LossCoefs`): loss coefficients use_var_contrastive_loss (`bool`): whether to use the variational @@ -609,9 +621,15 @@ def __init__( contrastive loss. Only used if `use_var_contrastive_loss` is set to `True`. """ - gw_mod = VariationalGWModule(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) + gw_mod = VariationalGWModule( + domain_mods, + workspace_dim, + gw_encoders, + gw_decoders, + ) + if use_var_contrastive_loss: if var_contrastive_loss is None: var_contrastive_loss = ContrastiveLossWithUncertainty( @@ -637,7 +655,6 @@ def __init__( super().__init__( gw_mod, - domain_mods, loss_mod, optim_lr, optim_weight_decay, @@ -655,7 +672,8 @@ class GlobalWorkspaceFusion(GlobalWorkspaceBase): def __init__( self, domain_mods: Mapping[str, DomainModule], - gw_interfaces: Mapping[str, GWInterfaceBase], + gw_encoders: Mapping[str, Module], + gw_decoders: Mapping[str, Module], workspace_dim: int, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, @@ -669,9 +687,12 @@ def __init__( domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`. 
- gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain - name to a `GWInterfaceBase` class which role is to encode/decode + gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to encode unimodal latent representations into a GW representation (pre fusion). + gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to decode a + GW representation into a unimodal latent representation. workspace_dim (`int`): dimension of the GW. optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay @@ -682,8 +703,13 @@ def __init__( function used for alignment. `learn_logit_scale` will not affect custom contrastive losses. """ - gw_mod = GWModuleFusion(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) + gw_mod = GWModuleFusion( + domain_mods, + workspace_dim, + gw_encoders, + gw_decoders, + ) if contrastive_loss is None: contrastive_loss = ContrastiveLoss( @@ -697,7 +723,6 @@ def __init__( super().__init__( gw_mod, - domain_mods, loss_mod, optim_lr, optim_weight_decay, @@ -708,7 +733,8 @@ def __init__( def pretrained_global_workspace( checkpoint_path: str | Path, domain_mods: Mapping[str, DomainModule], - gw_interfaces: Mapping[str, GWInterfaceBase], + gw_encoders: Mapping[str, Module], + gw_decoders: Mapping[str, Module], workspace_dim: int, loss_coefs: LossCoefs, contrastive_fn: ContrastiveLossType, @@ -722,9 +748,12 @@ def pretrained_global_workspace( domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`. 
- gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain - name to a `GWInterfaceBase` class which role is to encode/decode + gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to encode unimodal latent representations into a GW representation (pre fusion). + gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class whose role is to decode a + GW representation into a unimodal latent representation. workspace_dim (`int`): dimension of the GW. loss_coefs (`LossCoefs`): loss coefficients contrastive_loss (`ContrastiveLossType`): a contrastive loss @@ -739,8 +768,13 @@ def pretrained_global_workspace( Raises: `TypeError`: if loaded type is not `GlobalWorkspace`. """ - gw_mod = GWModule(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) + gw_mod = GWModule( + domain_mods, + workspace_dim, + gw_encoders, + gw_decoders, + ) loss_mod = GWLosses( gw_mod, domain_mods, @@ -750,7 +784,6 @@ def pretrained_global_workspace( gw = GlobalWorkspace.load_from_checkpoint( checkpoint_path, - domain_mods=domain_mods, gw_mod=gw_mod, loss_coefs=loss_coefs, loss_mod=loss_mod, diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 4165b161..4aebf7b6 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping -from typing import cast import torch from torch import nn @@ -28,7 +27,7 @@ def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]: class GWDecoder(nn.Sequential): - """A Decoder network used in GWInterfaces.""" + """A Decoder network for GWModules.""" def __init__( self, @@ -69,7 +68,7 @@ def __init__( class GWEncoder(GWDecoder): - """An Encoder network used in GWInterfaces. + """An Encoder network used in GWModules. 
This is similar to the decoder, but adds a tanh non-linearity at the end. """ @@ -96,6 +95,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.tanh(super().forward(input)) +class GWEncoderLinear(nn.Linear): + """A linear Encoder network used in GWModules.""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.tanh(super().forward(input)) + + class VariationalGWEncoder(nn.Module): """A Variational flavor of encoder network used in GWInterfaces.""" @@ -144,64 +150,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return self.layers(x), self.uncertainty_level.expand(x.size(0), -1) -class GWInterfaceBase(nn.Module, ABC): - """Base class for GWInterfaces. - - Interfaces encode and decode unimodal representation to the domain's GW - representation (pre-fusion). - - This is an abstract class and should be implemented. - For an implemented interface, see `GWInterface`. - """ - - def __init__(self, domain_module: DomainModule, workspace_dim: int) -> None: - """ - Initialized the interface. - - Args: - domain_module (`DomainModule`): Domain module to link. - workspace_dim (`int`): dimension of the GW. - """ - super().__init__() - - self.domain_module = domain_module - """Domain module.""" - - self.workspace_dim = workspace_dim - """Dimension of the GW.""" - - @abstractmethod - def encode(self, x: torch.Tensor) -> torch.Tensor: - """ - Encode from the unimodal latent representation to the domain's GW - representation (pre-fusion). - - Args: - x (`torch.Tensor`): the domain's unimodal latent representation. - - Returns: - `torch.Tensor`: the domain's pre-fusion GW representation. - """ - ... - - @abstractmethod - def decode(self, z: torch.Tensor) -> torch.Tensor: - """ - Decode from the domain's pre-fusion GW to the unimodal latent representation. - - Args: - z (`torch.Tensor`): the domain's pre-fusion GW representation. - - Returns: - `torch.Tensor`: the domain's unimodal latent representation. 
- """ - ... - - class GWModuleBase(nn.Module, ABC): """Base class for GWModule. - GWModule handle how to merge representations from the Interfaces and define + GWModule handles encoding and decoding the unimodal representations + using the `gw_encoders` and `gw_decoders`, and defines some common operations in GW like cycles and translations. This is an abstract class and should be implemented. @@ -209,26 +162,25 @@ class GWModuleBase(nn.Module, ABC): """ def __init__( - self, gw_interfaces: Mapping[str, GWInterfaceBase], workspace_dim: int + self, + domain_mods: Mapping[str, DomainModule], + workspace_dim: int, + *args, + **kwargs, ) -> None: """Initializes the GWModule. Args: - gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain - name to a `GWInterfaceBase` class which role is to encode/decode - unimodal latent representations into a GW representation (pre fusion). + domain_mods (`Mapping[str, DomainModule]`): the domain modules. workspace_dim (`int`): dimension of the GW. 
""" super().__init__() - # casting for LSP autocompletion - self.gw_interfaces = cast( - dict[str, GWInterfaceBase], nn.ModuleDict(gw_interfaces) - ) - """the GWInterface""" + self.domain_mods = domain_mods + """The unimodal domain modules.""" self.workspace_dim = workspace_dim - """dimension of the GW""" + """Dimension of the GW""" def on_before_gw_encode_dcy(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT: """ @@ -242,9 +194,7 @@ def on_before_gw_encode_dcy(self, x: LatentsDomainGroupT) -> LatentsDomainGroupD `LatentsDomainGroupDT`: the same mapping with updated representations """ return { - domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_dcy( - x[domain] - ) + domain: self.domain_mods[domain].on_before_gw_encode_dcy(x[domain]) for domain in x.keys() } @@ -260,9 +210,7 @@ def on_before_gw_encode_cy(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT `LatentsDomainGroupDT`: the same mapping with updated representations """ return { - domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_cy( - x[domain] - ) + domain: self.domain_mods[domain].on_before_gw_encode_cy(x[domain]) for domain in x.keys() } @@ -278,9 +226,7 @@ def on_before_gw_encode_tr(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT `LatentsDomainGroupDT`: the same mapping with updated representations """ return { - domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_tr( - x[domain] - ) + domain: self.domain_mods[domain].on_before_gw_encode_tr(x[domain]) for domain in x.keys() } @@ -296,9 +242,7 @@ def on_before_gw_encode_cont(self, x: LatentsDomainGroupT) -> LatentsDomainGroup `LatentsDomainGroupDT`: the same mapping with updated representations """ return { - domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_cont( - x[domain] - ) + domain: self.domain_mods[domain].on_before_gw_encode_cont(x[domain]) for domain in x.keys() } @@ -360,78 +304,35 @@ def cycle(self, x: LatentsDomainGroupT, through: str) -> 
LatentsDomainGroupDT: ... -class GWInterface(GWInterfaceBase): - """ - A implementation of `GWInterfaceBase` using `GWEncoder` and `GWDecoder`. - """ +class GWModule(GWModuleBase): + """GW nn.Module. Implements `GWModuleBase`.""" def __init__( self, - domain_module: DomainModule, + domain_modules: Mapping[str, DomainModule], workspace_dim: int, - encoder_hidden_dim: int, - encoder_n_layers: int, - decoder_hidden_dim: int, - decoder_n_layers: int, + gw_encoders: Mapping[str, nn.Module], + gw_decoders: Mapping[str, nn.Module], ) -> None: - """ - Initialized the interface. + """Initializes the GWModule. Args: - domain_module (`DomainModule`): Domain module to link. + domain_modules (`Mapping[str, DomainModule]`): the domain modules. workspace_dim (`int`): dimension of the GW. - encoder_hidden_dim (`int`): `hidden_dim` used for `GWEncoder`. - encoder_n_layers (`int`): `n_layers` used for `GWEncoder`. - decoder_hidden_dim (`int`): `hidden_dim` used for `GWDecoder`. - decoder_n_layers (`int`): `n_layers` used for `GWDecoder`. - """ - super().__init__(domain_module, workspace_dim) - - self.encoder = GWEncoder( - domain_module.latent_dim, - encoder_hidden_dim, - workspace_dim, - encoder_n_layers, - ) - """The interface encoder""" - - self.decoder = GWDecoder( - workspace_dim, - decoder_hidden_dim, - domain_module.latent_dim, - decoder_n_layers, - ) - """The interface decoder""" - - def encode(self, x: torch.Tensor) -> torch.Tensor: - """ - Encode the unimodal latent representation to the domain's pre-fusion GW - representation. - - Args: - x (`torch.Tensor`): the unimodal latent representation. - - Returns: - `torch.Tensor`: the domain's pre-fusion GW representation. - """ - return self.encoder(x) - - def decode(self, z: torch.Tensor) -> torch.Tensor: - """ - Decode from the domain's pre-fusion GW - representation to the unimodal latent representation. - - Args: - z (`torch.Tensor`): the domain's pre-fusion GW representation. 
- - Returns: - `torch.Tensor`: the unimodal latent representation. + gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class that encodes a + unimodal latent representation into a GW representation (pre fusion). + gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class that decodes a + GW representation to a unimodal latent representation. """ - return self.decoder(z) + super().__init__(domain_modules, workspace_dim) + self.gw_encoders = nn.ModuleDict(gw_encoders) + """The module's encoders""" -class GWModule(GWModuleBase): - """ """ + self.gw_decoders = nn.ModuleDict(gw_decoders) + """The module's decoders""" def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor: """ @@ -455,10 +356,7 @@ def encode(self, x: LatentsDomainGroupT) -> torch.Tensor: """ return self.fusion_mechanism( - { - domain: self.gw_interfaces[domain].encode(x[domain]) - for domain in x.keys() - } + {domain: self.gw_encoders[domain](x[domain]) for domain in x.keys()} ) def decode( @@ -474,8 +372,8 @@ def decode( `LatentsDomainGroupDT`: decoded unimodal representation """ return { - domain: self.gw_interfaces[domain].decode(z) - for domain in domains or self.gw_interfaces.keys() + domain: self.gw_decoders[domain](z) + for domain in domains or self.gw_decoders.keys() } def translate(self, x: LatentsDomainGroupT, to: str) -> torch.Tensor: @@ -509,67 +407,35 @@ def cycle(self, x: LatentsDomainGroupT, through: str) -> LatentsDomainGroupDT: } -class VariationalGWInterface(GWInterfaceBase): - """Variational flavor of `GWInterface`.""" +class VariationalGWModule(GWModuleBase): + """Variational flavor of `GWModule`.""" def __init__( self, - domain_module: DomainModule, + domain_modules: Mapping[str, DomainModule], workspace_dim: int, - encoder_hidden_dim: int, - encoder_n_layers: int, - decoder_hidden_dim: int, - decoder_n_layers: int, - ): - """Initializes the variational GWInterface. 
- - Args: - domain_module (`DomainModule`): domain module of the interface - workspace_dim (`int`): dimension of the GW - encoder_hidden_dim (`int`): `hidden_dim` of for `VariationalGWEncoder` - encoder_n_layers (`int`): `n_layers` for `VariationalGWEncoder` - decoder_hidden_dim (`int`): `hidden_dim` of for `GWDecoder` - decoder_n_layers (`int`): `n_layers` for `GWDecoder` - """ - - super().__init__(domain_module, workspace_dim) - - self.encoder = VariationalGWEncoder( - domain_module.latent_dim, - encoder_hidden_dim, - workspace_dim, - encoder_n_layers, - ) - self.decoder = GWDecoder( - workspace_dim, - decoder_hidden_dim, - domain_module.latent_dim, - decoder_n_layers, - ) - - def encode(self, x: torch.Tensor) -> torch.Tensor: - """Encode a unimodal representation into the pre-fusion GW representation. - - Args: - x (`torch.Tensor`): unimodal latent representation - Returns: - `torch.Tensor`: pre-fusion GW representation - """ - return self.encoder(x) - - def decode(self, z: torch.Tensor) -> torch.Tensor: - """Decode a GW representation into a unimodal representation. + gw_encoders: Mapping[str, nn.Module], + gw_decoders: Mapping[str, nn.Module], + ) -> None: + """Initializes the VariationalGWModule. Args: - z (`torch.Tensor`): GW representation. - Returns: - `torch.Tensor`: unimodal latent representation. + domain_modules (`Mapping[str, DomainModule]`): the domain modules. + workspace_dim (`int`): dimension of the GW. + gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class that encodes a + unimodal latent representation into a GW representation (pre fusion). + gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain + name to a `torch.nn.Module` class that decodes a + GW representation to a unimodal latent representation. 
""" - return self.decoder(z) + super().__init__(domain_modules, workspace_dim) + self.gw_encoders = nn.ModuleDict(gw_encoders) + """The module's encoders""" -class VariationalGWModule(GWModuleBase): - """Variational flavor of `GWModule`.""" + self.gw_decoders = nn.ModuleDict(gw_decoders) + """The module's decoders""" def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor: """Fusion of the pre-fusion GW representations. @@ -596,7 +462,7 @@ def encode( """ latents: LatentsDomainGroupDT = {} for domain in x.keys(): - mean, log_uncertainty = self.gw_interfaces[domain].encode(x[domain]) + mean, log_uncertainty = self.gw_encoders[domain](x[domain]) latents[domain] = reparameterize(mean, log_uncertainty) return self.fusion_mechanism(latents) @@ -617,7 +483,7 @@ def encoded_distribution( means: LatentsDomainGroupDT = {} log_uncertainties: LatentsDomainGroupDT = {} for domain in x.keys(): - mean, log_uncertainty = self.gw_interfaces[domain].encode(x[domain]) + mean, log_uncertainty = self.gw_encoders[domain](x[domain]) means[domain] = mean log_uncertainties[domain] = log_uncertainty return means, log_uncertainties @@ -652,8 +518,8 @@ def decode( `LatentsDomainGroupDT`: decoded unimodal representations. 
""" return { - domain: self.gw_interfaces[domain].decode(z) - for domain in domains or self.gw_interfaces.keys() + domain: self.gw_decoders[domain](z) + for domain in domains or self.gw_decoders.keys() } def translate(self, x: LatentsDomainGroupT, to: str) -> torch.Tensor: @@ -711,16 +577,16 @@ def encode(self, x: LatentsDomainGroupT) -> torch.Tensor: domains = {} bs = group_batch_size(x) device = group_device(x) - for domain in self.gw_interfaces.keys(): + for domain in self.domain_mods.keys(): if domain in x: domains[domain] = x[domain] else: domains[domain] = torch.zeros( - bs, self.gw_interfaces[domain].domain_module.latent_dim + bs, self.domain_mods[domain].latent_dim ).to(device) return self.fusion_mechanism( { - domain_name: self.gw_interfaces[domain_name].encode(domain) + domain_name: self.gw_encoders[domain_name](domain) for domain_name, domain in domains.items() } ) diff --git a/tests/test_training.py b/tests/test_training.py index fa3648ba..2668fa27 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,7 +1,7 @@ import torch.utils.data from utils import DummyData, DummyDataset, DummyDomainModule -from shimmer import GlobalWorkspace, GWInterface +from shimmer import GlobalWorkspace, GWDecoder, GWEncoder def test_training(): @@ -18,36 +18,52 @@ def test_training(): workspace_dim = 16 - gw_interfaces = { - "v": GWInterface( - domains["v"], - workspace_dim=workspace_dim, - encoder_hidden_dim=64, - encoder_n_layers=1, - decoder_hidden_dim=64, - decoder_n_layers=1, + gw_encoders = { + "v": GWEncoder( + domains["v"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, ), - "t": GWInterface( - domains["t"], - workspace_dim=workspace_dim, - encoder_hidden_dim=64, - encoder_n_layers=1, - decoder_hidden_dim=64, - decoder_n_layers=1, + "t": GWEncoder( + domains["t"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, ), - "a": GWInterface( - domains["a"], - workspace_dim=workspace_dim, - encoder_hidden_dim=64, - 
encoder_n_layers=1, - decoder_hidden_dim=64, - decoder_n_layers=1, + "a": GWEncoder( + domains["a"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, + ), + } + + gw_decoders = { + "v": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["v"].latent_dim, + n_layers=1, + ), + "t": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["t"].latent_dim, + n_layers=1, + ), + "a": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["a"].latent_dim, + n_layers=1, ), } gw = GlobalWorkspace( domains, - gw_interfaces, + gw_encoders, + gw_decoders, workspace_dim=16, loss_coefs={}, )