From 74c8c8555117222e965f8eaa172535c9a27ba990 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Fri, 1 Mar 2024 17:17:28 +0000 Subject: [PATCH] Continue docs WIP --- shimmer/modules/domain.py | 14 ++++----- shimmer/modules/global_workspace.py | 6 ++-- shimmer/modules/gw_module.py | 34 +++++++++++++++++--- shimmer/types.py | 48 +++++++++++++++++++---------- 4 files changed, 72 insertions(+), 30 deletions(-) diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py index f92a4706..6d0d96c1 100644 --- a/shimmer/modules/domain.py +++ b/shimmer/modules/domain.py @@ -9,14 +9,14 @@ class LossOutput: """This is a python dataclass use as a returned value for losses. It keeps track of what is used for training (`loss`) and what is used - only for logging (`metrics`) + only for logging (`metrics`). """ loss: torch.Tensor - """Loss used during training""" + """Loss used during training.""" metrics: dict[str, torch.Tensor] = field(default_factory=dict) - """Some additional metrics to log (not used during training)""" + """Some additional metrics to log (not used during training).""" def __post_init__(self): if "loss" in self.metrics.keys(): @@ -25,7 +25,7 @@ def __post_init__(self): @property def all(self) -> dict[str, torch.Tensor]: """ - Returns a dict with all metrics and loss with "loss" key + Returns a dict with all metrics and loss with "loss" key. """ return {**self.metrics, "loss": self.loss} @@ -34,9 +34,9 @@ class DomainModule(pl.LightningModule): """ Base class for a DomainModule that defines domain specific modules of the GW. - > [!NOTE] - > We do not use ABC here because some modules could - > be without encore or decoder. + .. note:: + We do not use ABC here because some modules could + be without encore or decoder. """ def __init__( diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index ce1eef1a..3a0edc87 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -476,9 +476,9 @@ def freeze_domain_modules( ) -> dict[str, DomainModule]: """Freezes weights and set to eval mode the domain modules. - > [!NOTE] - > The output is casted as `dict[str, DomainModule]` type for better auto-completion, - > but is actually a torch `ModuleDict`. + .. note:: + The output is casted as `dict[str, DomainModule]` type for better + auto-completion, but is actually a torch `ModuleDict`. Args: domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index ea59f027..501501d8 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -440,7 +440,7 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: class GWModule(GWModuleBase): """ """ - def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: + def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor: """ Merge function used to combine domains. @@ -451,7 +451,17 @@ def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: """ return torch.mean(torch.stack(list(x.values())), dim=0) - def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: + def encode(self, x: LatentsDomainGroupT) -> torch.Tensor: + """ + Encode the unimodal latent representation `x` into the GW representation + + Args: + x (`LatentsDomainGroupT`) + + Returns: + `torch.Tensor` + + """ return self.fusion_mechanism( { domain: self.gw_interfaces[domain].encode(x[domain]) @@ -461,13 +471,29 @@ def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: def decode( self, z: torch.Tensor, domains: Iterable[str] | None = None - ) -> dict[str, torch.Tensor]: + ) -> LatentsDomainGroupT: + """Decodes a GW representation to multiple domains. + + Args: + z: the GW representation + domains: the domains to decode to. If not given, will use + keys in `gw_interfaces` (all domains). + """ return { domain: self.gw_interfaces[domain].decode(z) for domain in domains or self.gw_interfaces.keys() } - def translate(self, x: Mapping[str, torch.Tensor], to: str) -> torch.Tensor: + def translate(self, x: LatentsDomainGroupT, to: str) -> torch.Tensor: + """Translate from multiple domains to one domain. + + Args: + x: the group of latent representations + to: the domain name to encode to + + Returns: + the translated unimodal representation of the provided domain. + """ return self.decode(self.encode(x), domains={to})[to] def cycle( diff --git a/shimmer/types.py b/shimmer/types.py index f6c41621..ea61e806 100644 --- a/shimmer/types.py +++ b/shimmer/types.py @@ -4,42 +4,58 @@ import torch RawDomainGroupT = Mapping[str, Any] -"""Matched raw unimodal data from multiple domains. -Keys of the mapping are domains names.""" +""" +Matched raw unimodal data from multiple domains. +Keys of the mapping are domains names. +""" RawDomainGroupDT = dict[str, Any] -"""Matched raw unimodal data from multiple domains. +""" +Matched raw unimodal data from multiple domains. Keys of the dict are domains names. -This is a more specific version of `RawDomainGroupT` used in method's outputs.""" +This is a more specific version of `RawDomainGroupT` used in method's outputs. +""" LatentsDomainGroupT = Mapping[str, torch.Tensor] -"""Matched unimodal latent representations from multiple domains. -Keys of the mapping are domains names.""" +""" +Matched unimodal latent representations from multiple domains. +Keys of the mapping are domains names. +""" LatentsDomainGroupDT = dict[str, torch.Tensor] -"""Matched unimodal latent representations from multiple domains. +""" +Matched unimodal latent representations from multiple domains. Keys of the dict are domains names. -This is a more specific version of `LatentsDomainGroupT` used in method's outputs.""" +This is a more specific version of `LatentsDomainGroupT` used in method's outputs. +""" LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT] -"""Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group. -Each group is independent and contains different data (unpaired).""" +""" +Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired). +""" LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT] -"""Mapping of `LatentsDomainGroupDT`. +""" +Mapping of `LatentsDomainGroupDT`. Keys are frozenset of domains matched in the group. Each group is independent and contains different data (unpaired). -This is a more specific version of `LatentsDomainGroupsT` used in method's outputs.""" +This is a more specific version of `LatentsDomainGroupsT` used in method's outputs. +""" RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT] -"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. -Each group is independent and contains different data (unpaired).""" +""" +Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired). +""" RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT] -"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. +""" +Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. Each group is independent and contains different data (unpaired). -This is a more specific version of `RawDomainGroupsT` used in method's outputs.""" +This is a more specific version of `RawDomainGroupsT` used in method's outputs. +"""