diff --git a/shimmer/__init__.py b/shimmer/__init__.py index cfc082e9..3f3e7bdb 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -30,12 +30,20 @@ from shimmer.modules.losses import ( GWLosses, GWLossesBase, - LatentsDomainGroupT, - LatentsT, LossCoefs, VariationalGWLosses, VariationalLossCoefs, ) +from shimmer.types import ( + LatentsDomainGroupDT, + LatentsDomainGroupsDT, + LatentsDomainGroupsT, + LatentsDomainGroupT, + RawDomainGroupDT, + RawDomainGroupsDT, + RawDomainGroupsT, + RawDomainGroupT, +) from shimmer.version import __version__ __all__ = [ @@ -60,10 +68,8 @@ "ContrastiveLossWithUncertainty", "contrastive_loss", "contrastive_loss_with_uncertainty", - "LatentsT", "LossCoefs", "VariationalLossCoefs", - "LatentsDomainGroupT", "GWLosses", "GWLossesBase", "VariationalGWLosses", @@ -74,4 +80,12 @@ "GWPredictions", "pretrained_global_workspace", "RepeatedDataset", + "LatentsDomainGroupDT", + "LatentsDomainGroupsDT", + "LatentsDomainGroupsT", + "LatentsDomainGroupT", + "RawDomainGroupDT", + "RawDomainGroupsDT", + "RawDomainGroupsT", + "RawDomainGroupT", ] diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index 3191fbf2..29801681 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -29,8 +29,8 @@ from shimmer.modules.losses import ( GWLosses, GWLossesBase, + LatentsDomainGroupsT, LatentsDomainGroupT, - LatentsT, LossCoefs, VariationalGWLosses, VariationalLossCoefs, @@ -55,7 +55,7 @@ "ContrastiveLossWithUncertainty", "contrastive_loss", "contrastive_loss_with_uncertainty", - "LatentsT", + "LatentsDomainGroupsT", "LossCoefs", "VariationalLossCoefs", "LatentsDomainGroupT", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 3035c859..ce1eef1a 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -26,15 +26,17 @@ GWLosses, GWLossesBase, GWLossesFusion, + LossCoefs, + VariationalGWLosses, + VariationalLossCoefs, +) +from shimmer.types import ( LatentsDomainGroupsDT, LatentsDomainGroupsT, LatentsDomainGroupT, - LossCoefs, RawDomainGroupsDT, RawDomainGroupsT, RawDomainGroupT, - VariationalGWLosses, - VariationalLossCoefs, ) diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 9cab1dfa..ffdb5322 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -6,8 +6,8 @@ from torch import nn from shimmer.modules.domain import DomainModule -from shimmer.modules.losses import LatentsDomainGroupT from shimmer.modules.vae import reparameterize +from shimmer.types import LatentsDomainGroupT def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]: @@ -381,6 +381,17 @@ def __init__( decoder_hidden_dim: int, decoder_n_layers: int, ) -> None: + """ + Initialized the interface. + + Args: + domain_module (`DomainModule`): Domain module to link. + workspace_dim (`int`): dimension of the GW. + encoder_hidden_dim (`int`): `hidden_dim` used for `GWEncoder`. + encoder_n_layers (`int`): `n_layers` used for `GWEncoder`. + decoder_hidden_dim (`int`): `hidden_dim` used for `GWDecoder`. + decoder_n_layers (`int`): `n_layers` used for `GWDecoder`. + """ super().__init__(domain_module, workspace_dim) self.encoder = GWEncoder( @@ -389,12 +400,15 @@ def __init__( workspace_dim, encoder_n_layers, ) + """The interface encoder""" + self.decoder = GWDecoder( workspace_dim, decoder_hidden_dim, domain_module.latent_dim, decoder_n_layers, ) + """The interface decoder""" def encode(self, x: torch.Tensor) -> torch.Tensor: return self.encoder(x) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 88d835e4..22f70f2f 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Mapping -from typing import Any, TypedDict +from typing import TypedDict import torch import torch.nn.functional as F @@ -9,47 +8,7 @@ from shimmer.modules.domain import DomainModule, LossOutput from shimmer.modules.gw_module import GWModule, GWModuleBase, VariationalGWModule from shimmer.modules.vae import kl_divergence_loss - -RawDomainGroupT = Mapping[str, Any] -"""Matched raw unimodal data from multiple domains. -Keys of the mapping are domains names.""" - -RawDomainGroupDT = dict[str, Any] -"""Matched raw unimodal data from multiple domains. -Keys of the dict are domains names. - -This is a more specific version of `RawDomainGroupT` used in method's outputs.""" - -LatentsDomainGroupT = Mapping[str, torch.Tensor] -"""Matched unimodal latent representations from multiple domains. -Keys of the mapping are domains names.""" - -LatentsDomainGroupDT = dict[str, torch.Tensor] -"""Matched unimodal latent representations from multiple domains. -Keys of the dict are domains names. - -This is a more specific version of `LatentsDomainGroupT` used in method's outputs.""" - -LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT] -"""Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group. -Each group is independent and contains different data (unpaired).""" - -LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT] -"""Mapping of `LatentsDomainGroupDT`. -Keys are frozenset of domains matched in the group. -Each group is independent and contains different data (unpaired). - -This is a more specific version of `LatentsDomainGroupsT` used in method's outputs.""" - -RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT] -"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. -Each group is independent and contains different data (unpaired).""" - -RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT] -"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. -Each group is independent and contains different data (unpaired). - -This is a more specific version of `RawDomainGroupsT` used in method's outputs.""" +from shimmer.types import LatentsDomainGroupsT class GWLossesBase(torch.nn.Module, ABC): diff --git a/shimmer/types.py b/shimmer/types.py new file mode 100644 index 00000000..f6c41621 --- /dev/null +++ b/shimmer/types.py @@ -0,0 +1,45 @@ +from collections.abc import Mapping +from typing import Any + +import torch + +RawDomainGroupT = Mapping[str, Any] +"""Matched raw unimodal data from multiple domains. +Keys of the mapping are domains names.""" + +RawDomainGroupDT = dict[str, Any] +"""Matched raw unimodal data from multiple domains. +Keys of the dict are domains names. + +This is a more specific version of `RawDomainGroupT` used in method's outputs.""" + +LatentsDomainGroupT = Mapping[str, torch.Tensor] +"""Matched unimodal latent representations from multiple domains. +Keys of the mapping are domains names.""" + +LatentsDomainGroupDT = dict[str, torch.Tensor] +"""Matched unimodal latent representations from multiple domains. +Keys of the dict are domains names. + +This is a more specific version of `LatentsDomainGroupT` used in method's outputs.""" + +LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT] +"""Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).""" + +LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT] +"""Mapping of `LatentsDomainGroupDT`. +Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired). + +This is a more specific version of `LatentsDomainGroupsT` used in method's outputs.""" + +RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT] +"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).""" + +RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT] +"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired). + +This is a more specific version of `RawDomainGroupsT` used in method's outputs."""