Skip to content

Commit

Permalink
Continue docs WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 1, 2024
1 parent 14fe1bd commit 74c8c85
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 30 deletions.
14 changes: 7 additions & 7 deletions shimmer/modules/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
class LossOutput:
"""This is a python dataclass use as a returned value for losses.
It keeps track of what is used for training (`loss`) and what is used
only for logging (`metrics`)
only for logging (`metrics`).
"""

loss: torch.Tensor
"""Loss used during training"""
"""Loss used during training."""

metrics: dict[str, torch.Tensor] = field(default_factory=dict)
"""Some additional metrics to log (not used during training)"""
"""Some additional metrics to log (not used during training)."""

def __post_init__(self):
if "loss" in self.metrics.keys():
Expand All @@ -25,7 +25,7 @@ def __post_init__(self):
@property
def all(self) -> dict[str, torch.Tensor]:
"""
Returns a dict with all metrics and loss with "loss" key
Returns a dict with all metrics and loss with "loss" key.
"""
return {**self.metrics, "loss": self.loss}

Expand All @@ -34,9 +34,9 @@ class DomainModule(pl.LightningModule):
"""
Base class for a DomainModule that defines domain specific modules of the GW.
> [!NOTE]
> We do not use ABC here because some modules could
> be without encore or decoder.
.. note::
We do not use ABC here because some modules could
be without encore or decoder.
"""

def __init__(
Expand Down
6 changes: 3 additions & 3 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,9 @@ def freeze_domain_modules(
) -> dict[str, DomainModule]:
"""Freezes weights and set to eval mode the domain modules.
> [!NOTE]
> The output is casted as `dict[str, DomainModule]` type for better auto-completion,
> but is actually a torch `ModuleDict`.
.. note::
The output is casted as `dict[str, DomainModule]` type for better
auto-completion, but is actually a torch `ModuleDict`.
Args:
domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
Expand Down
34 changes: 30 additions & 4 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
class GWModule(GWModuleBase):
""" """

def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor:
"""
Merge function used to combine domains.
Expand All @@ -451,7 +451,17 @@ def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
"""
return torch.mean(torch.stack(list(x.values())), dim=0)

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
def encode(self, x: LatentsDomainGroupT) -> torch.Tensor:
"""
Encode the unimodal latent representation `x` into the GW representation
Args:
x (`LatentsDomainGroupT`)
Returns:
`torch.Tensor`
"""
return self.fusion_mechanism(
{
domain: self.gw_interfaces[domain].encode(x[domain])
Expand All @@ -461,13 +471,29 @@ def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:

def decode(
self, z: torch.Tensor, domains: Iterable[str] | None = None
) -> dict[str, torch.Tensor]:
) -> LatentsDomainGroupT:
"""Decodes a GW representation to multiple domains.
Args:
z: the GW representation
domains: the domains to decode to. If not given, will use
keys in `gw_interfaces` (all domains).
"""
return {
domain: self.gw_interfaces[domain].decode(z)
for domain in domains or self.gw_interfaces.keys()
}

def translate(self, x: Mapping[str, torch.Tensor], to: str) -> torch.Tensor:
def translate(self, x: LatentsDomainGroupT, to: str) -> torch.Tensor:
"""Translate from multiple domains to one domain.
Args:
x: the group of latent representations
to: the domain name to encode to
Returns:
the translated unimodal representation of the provided domain.
"""
return self.decode(self.encode(x), domains={to})[to]

def cycle(
Expand Down
48 changes: 32 additions & 16 deletions shimmer/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,58 @@
import torch

RawDomainGroupT = Mapping[str, Any]
"""Matched raw unimodal data from multiple domains.
Keys of the mapping are domains names."""
"""
Matched raw unimodal data from multiple domains.
Keys of the mapping are domains names.
"""

RawDomainGroupDT = dict[str, Any]
"""Matched raw unimodal data from multiple domains.
"""
Matched raw unimodal data from multiple domains.
Keys of the dict are domains names.
This is a more specific version of `RawDomainGroupT` used in method's outputs."""
This is a more specific version of `RawDomainGroupT` used in method's outputs.
"""

LatentsDomainGroupT = Mapping[str, torch.Tensor]
"""Matched unimodal latent representations from multiple domains.
Keys of the mapping are domains names."""
"""
Matched unimodal latent representations from multiple domains.
Keys of the mapping are domains names.
"""

LatentsDomainGroupDT = dict[str, torch.Tensor]
"""Matched unimodal latent representations from multiple domains.
"""
Matched unimodal latent representations from multiple domains.
Keys of the dict are domains names.
This is a more specific version of `LatentsDomainGroupT` used in method's outputs."""
This is a more specific version of `LatentsDomainGroupT` used in method's outputs.
"""

LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT]
"""Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired)."""
"""
Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired).
"""

LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT]
"""Mapping of `LatentsDomainGroupDT`.
"""
Mapping of `LatentsDomainGroupDT`.
Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired).
This is a more specific version of `LatentsDomainGroupsT` used in method's outputs."""
This is a more specific version of `LatentsDomainGroupsT` used in method's outputs.
"""

RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT]
"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired)."""
"""
Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired).
"""

RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT]
"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
"""
Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
Each group is independent and contains different data (unpaired).
This is a more specific version of `RawDomainGroupsT` used in method's outputs."""
This is a more specific version of `RawDomainGroupsT` used in method's outputs.
"""

0 comments on commit 74c8c85

Please sign in to comment.